gaoqiong / flash-attention · Commits

Commit 4f285b35, authored Jul 17, 2023 by Tri Dao
Parent: 6d48e14a
Changes: 90

FlashAttention-2 release

Showing 20 changed files with 972 additions and 5012 deletions.
csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu              +9    -0
csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu              +9    -0
csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu              +9    -0
csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu               +10   -0
csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu               +23   -0
csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu               +19   -0
csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu               +26   -0
csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu               +17   -0
csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu               +23   -0
csrc/flash_attn/src/flash_fwd_kernel.h                          +576  -0
csrc/flash_attn/src/flash_fwd_launch_template.h                 +251  -0
csrc/flash_attn/src/fmha.h                                      +0    -211
csrc/flash_attn/src/fmha/gemm.h                                 +0    -451
csrc/flash_attn/src/fmha/gmem_tile.h                            +0    -555
csrc/flash_attn/src/fmha/kernel_traits.h                        +0    -116
csrc/flash_attn/src/fmha/mask.h                                 +0    -90
csrc/flash_attn/src/fmha/smem_tile.h                            +0    -1703
csrc/flash_attn/src/fmha/softmax.h                              +0    -607
csrc/flash_attn/src/fmha/utils.h                                +0    -1215
csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu   +0    -64
csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
}
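This file and the eight sibling *_sm80.cu files that follow each provide exactly one explicit specialization of the same templated entry point, so the expensive kernel instantiations are spread across translation units and can be compiled in parallel. A minimal sketch of the pattern, assuming the primary template is declared in a shared header such as flash.h (that header is not shown in this diff):

// Hypothetical sketch of the split-compilation pattern, not code from this commit.
// A shared header declares the templated runner once:
template<typename T, int Headdim>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);

// Each per-headdim .cu file then supplies a single explicit specialization, exactly as
// above, forwarding to the corresponding run_mha_fwd_hdimXX<T> from the launch template:
// template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream);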
csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::half_t;
//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//         run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
//         // For dropout there might be a lot of register spilling?
//         // These two are very slow due to register spilling
//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
//         // This one is slightly slower
//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
//     });
// }
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::bfloat16_t;
//     if (params.p_dropout == 1.f) {
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
//     } else {
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
//     }
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::half_t;
//     if (params.p_dropout == 1.f) {
//         // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
//         // Using block size (64 x 256) is 27% slower for seqlen=2k
//         // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
//     } else {
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
//     }
// }
template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::bfloat16_t;
//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
//     });
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::half_t;
//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
//         // This 3rd one is good for H100, and A100, A6000
//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
//         // These two are always slower
//         // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
//     });
// }
template<>
void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
}
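The commented-out tuning code above, like the launch template later in this commit, dispatches on runtime flags such as params.p_dropout < 1.f through a BOOL_SWITCH macro from static_switch.h, which is not itself shown in this diff. A typical implementation of such a switch, given here only as a hedged sketch, instantiates the body twice so the flag becomes a compile-time constant in each branch:

// Sketch of a BOOL_SWITCH-style macro (assumed; the real static_switch.h is not in this diff).
// Both branches compile the lambda body, once with CONST_NAME == true and once with false,
// so CONST_NAME can be used as a template argument inside the body.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                       \
        if (COND) {                             \
            constexpr bool CONST_NAME = true;   \
            return __VA_ARGS__();               \
        } else {                                \
            constexpr bool CONST_NAME = false;  \
            return __VA_ARGS__();               \
        }                                       \
    }()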
csrc/flash_attn/src/flash_fwd_kernel.h (new file, mode 100644)

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cmath>

#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "philox.cuh"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int MMA_M, class... Args, class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
                                  TiledMMA           const& tiled_mma) {
    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
    constexpr int MMAStride_M = MMA_M * AtomShape_M;
    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
                              Stride<_1, Int<MMAStride_M>>>{},
                       make_layout(size<2>(TileShape_MNK{})));
    // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int MMA_M, class... Args, class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
                                  TiledMMA           const& tiled_mma) {
    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
    constexpr int MMAStride_M = MMA_M * AtomShape_M;
    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
                              Stride<_1, Int<MMAStride_M>>>{},
                       // TODO: Shouldn't this be size<1>?
                       make_layout(size<2>(TileShape_MNK{})));
    // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
                                         Tensor2 &acc_o, float softmax_scale_log2) {
    if (Is_first) {
        flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
        flash::reduce_sum(scores, scores_sum);
    } else {
        Tensor scores_max_prev = make_fragment_like(scores_max);
        copy(scores_max, scores_max_prev);
        flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        #pragma unroll
        for (int mi = 0; mi < size(scores_max); ++mi) {
            float scores_max_cur = !Check_inf
                ? scores_max(mi)
                : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
            float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
            scores_sum(mi) *= scores_scale;
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
        }
        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
        Tensor scores_sum_cur = make_fragment_like(scores_sum);
        flash::reduce_sum(scores, scores_sum_cur);
        #pragma unroll
        for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
    Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
) {
    // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
    Layout l = tOrP.layout();
    Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
    CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
    CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
    #pragma unroll
    for (int mi = 0; mi < size<1>(tPrP); ++mi) {
        copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNWarps = Kernel_traits::kNWarps;
    constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;

    const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;

    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
    if (Is_causal) {
        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
        //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
        // }
    }

    // We iterate over the blocks in reverse order. This is because the last block is the only one
    // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
    // might save us 1 register (we just need n_block instead of both n_block and n_block_max).

    const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
        + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
    // We move K and V to the last block.
    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
        + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
        + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
    const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
        + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;

    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.q_row_stride, _1{}));
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.k_row_stride, _1{}));
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.v_row_stride, _1{}));
    Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
                            Shape<Int<kBlockM>, Int<kBlockN>>{},
                            make_stride(params.seqlen_k_rounded, _1{}));

    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                            typename Kernel_traits::SmemLayoutQ{});
    // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
                            typename Kernel_traits::SmemLayoutKV{});
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

    auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
    auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);

    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
    Tensor tPgP = gmem_thr_copy_P.partition_D(gP);

    typename Kernel_traits::TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    Tensor tSrQ  = thr_mma.partition_fragment_A(sQ);             // (MMA,MMA_M,MMA_K)
    Tensor tSrK  = thr_mma.partition_fragment_B(sK);             // (MMA,MMA_N,MMA_K)
    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);   // (MMA, MMA_K,MMA_N)

    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K

    //
    // Copy Atom retiling
    //

    auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
    // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
    // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
    // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}

    auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);

    auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

    // TODO: this might need to change if we change the mma instruction in SM70
    Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
    Tensor scores_sum = make_fragment_like(scores_max);

    //
    // PREDICATES
    //

    // // Allocate predicate tensors for m and n
    // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
    // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});

    // Construct identity layout for sQ and sK
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));   // (BLK_N,BLK_K) -> (blk_n,blk_k)
    // Tensor tScQ = thr_mma.partition_A(cQ);                                  // (MMA,MMA_M,MMA_K)
    // if (cute::thread0()) {
    //     print(tScQ.layout()); printf("\n");
    //     for (int i = 0; i < size(tScQ); ++i) {
    //         printf("%d ", get<0>(tScQ(i)));
    //     }
    //     printf("\n");
    //     for (int i = 0; i < size(tScQ); ++i) {
    //         printf("%d ", get<1>(tScQ(i)));
    //     }
    //     printf("\n");
    // }

    // Repeat the partitioning with identity layouts
    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);    // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)

    // Allocate predicate tensors for k
    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));

    // Set predicates for k bounds
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
    }

    // Prologue

    Tensor tQrQ = make_fragment_like(tQgQ);
    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
    flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                                 binfo.actual_seqlen_q - m_block * kBlockM);
    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

    // // Copy rmem to smem
    // // copy(tQrQ, tQsQ);
    // flash::cp_async_wait<0>();
    // __syncthreads();
    // // if (cute::thread(1, 0)) { print(tQsQ); }
    // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
    // // if (cute::thread0()) { print(sQNoSwizzle); }

    if (Kernel_traits::Share_Q_K_smem) {
        flash::cp_async_wait<0>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
        copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
        __syncthreads();
    }

    int n_block = n_block_max - 1;
    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
    flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                                      binfo.actual_seqlen_k - n_block * kBlockN);
    cute::cp_async_fence();
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
    // __syncthreads();

    if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
        flash::cp_async_wait<1>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
        copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
    }

    auto seeds = at::cuda::philox::unpack(params.philox_args);
    unsigned long long seed = std::get<0>(seeds);
    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;

    clear(acc_o);

    // For performance reason, we separate out two kinds of iterations:
    // those that need masking on S, and those that don't.
    // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
    // We will have at least 1 "masking" iteration.
    constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
    #pragma unroll
    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        flash::cp_async_wait<0>();
        __syncthreads();

        // Advance gV
        if (masking_step > 0) {
            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        } else {
            // Clear the smem tiles to account for predicated off loads
            flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
                gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
            );
        }
        cute::cp_async_fence();

        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
        );
        // if (cute::thread0()) { print(acc_s); }

        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        // if (cute::thread0()) { print(scores); }
        // We don't put the masking before the matmul S = Q K^T because we don't clear sK
        // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
        // can produce Inf / NaN.
        if (!Is_causal) {
            if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
        } else {
            // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
            // Tensor taccScS = thr_mma.partition_C(caccS);                                 // (MMA,MMA_M,MMA_N)
            // static_assert(decltype(size<0>(taccScS))::value == 4);
            // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
            // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
            // Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout()));
            // flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
            //                                m_block * kBlockM);
            // Idk why it's get<1> and not get<0> of the stride.
            // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
            // I can't get the stride from idx_row
            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
                                     // m_block * kBlockM + get<0>(idx_row(0)),
                                     m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                                     kNWarps * 16);
                                     // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
                                     // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
        }

        flash::cp_async_wait<0>();
        __syncthreads();
        if (n_block > 0) {
            // Advance gK
            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        // TODO: when we have key_padding_mask we'll need to Check_inf
        masking_step == 0
            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

        // Convert scores from fp32 to fp16/bf16
        Tensor rP = flash::convert_type<Element>(scores);
        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
        uint32_t block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor tOrP_copy = make_fragment_like(tOrP);
            copy(tOrP, tOrP_copy);
            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                block_row_idx, block_col_idx, kNWarps
            );
            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
            tPgP.data() = tPgP.data() + (-kBlockN);
        }
        if (Is_dropout) {
            flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
                                 block_row_idx, block_col_idx, kNWarps);
        }
        // if (cute::thread0()) { print(tOrP); }

        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
        // if (cute::thread0()) { print(scores); }

        // This check is at the end of the loop since we always have at least 1 iteration
        if (n_masking_steps > 1 && n_block <= 0) {
            --n_block;
            break;
        }
    }

    // These are the iterations where we don't need masking on S
    for (; n_block >= 0; --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        flash::cp_async_wait<0>();
        __syncthreads();
        // Advance gV
        tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        cute::cp_async_fence();

        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
        );

        flash::cp_async_wait<0>();
        __syncthreads();
        if (n_block > 0) {
            // Advance gK
            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

        Tensor rP = flash::convert_type<Element>(scores);
        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));

        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
        uint32_t block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor tOrP_copy = make_fragment_like(tOrP);
            copy(tOrP, tOrP_copy);
            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                block_row_idx, block_col_idx, kNWarps
            );
            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
            tPgP.data() = tPgP.data() + (-kBlockN);
        }
        if (Is_dropout) {
            flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
                                 block_row_idx, block_col_idx, kNWarps);
        }

        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
    }

    // Epilogue

    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    Tensor lse = make_fragment_like(scores_sum);
    #pragma unroll
    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
        float sum = scores_sum(mi);
        float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
        lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
        float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
        #pragma unroll
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
    }

    // if (cute::thread0()) { print(acc_o_rowcol); }

    // Convert acc_o from fp32 to fp16/bf16
    Tensor rO = flash::convert_type<Element>(acc_o);
    Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
    // Partition sO to match the accumulator partitioning
    auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
    // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
    Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

    // sO has the same size as sQ, so we don't need to sync here.
    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }

    copy(smem_thr_copy_O, taccOrO, taccOsO);

    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.o_row_stride, _1{}));
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});

    auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
    Tensor tOsO = gmem_thr_copy_O.partition_S(sO);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);

    __syncthreads();

    Tensor tOrO = make_tensor<Element>(shape(tOgO));
    copy(gmem_thr_copy_O, tOsO, tOrO);

    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor taccOcO = thr_mma.partition_C(caccO);                                  // (MMA,MMA_M,MMA_K)
    static_assert(decltype(size<0>(taccOcO))::value == 4);
    // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                         // MMA_M
    if (get<1>(taccOcO_row(0)) == 0) {
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
        }
    }

    // Construct identity layout for sO
    Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    // Repeat the partitioning with identity layouts
    Tensor tOcO = gmem_thr_copy_O.partition_D(cO);                             // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
    );
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;

    // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
    // them to have the same number of threads or have to traverse the attention matrix
    // in the same order.
    // In the Philox RNG, we use the offset to store the batch, head, and the lane id
    // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
    // the attention matrix. This way, as long as we have the batch, head, and the location of
    // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
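softmax_rescale_o above is the online-softmax update at the heart of this kernel: each new block of scores updates a running row maximum and a running row sum, and the partial output accumulator acc_o is rescaled by exp2((m_old - m_new) * scale_log2) so that the final division by the row sum in the epilogue yields the exact softmax-weighted output. A scalar, single-row analogue of that recurrence, written here purely for illustration and not part of this commit:

#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative scalar analogue of softmax_rescale_o plus the epilogue (simplified sketch,
// not code from the commit): one attention row is processed block by block while keeping
// a running max m, a running sum l, and an un-normalized accumulator acc of p * value.
float online_softmax_row(const std::vector<float>& scores, const std::vector<float>& values,
                         size_t block_n) {
    float m = -INFINITY, l = 0.f, acc = 0.f;
    for (size_t start = 0; start < scores.size(); start += block_n) {
        const size_t end = std::min(scores.size(), start + block_n);
        float m_new = m;
        for (size_t i = start; i < end; ++i) { m_new = std::max(m_new, scores[i]); }
        // Rescale the old partial sums to the new max; the kernel applies the same factor
        // to acc_o, using exp2f with a log2-scaled softmax scale.
        const float rescale = (l == 0.f) ? 0.f : std::exp(m - m_new);
        l *= rescale;
        acc *= rescale;
        for (size_t i = start; i < end; ++i) {
            const float p = std::exp(scores[i] - m_new);
            l += p;
            acc += p * values[i];
        }
        m = m_new;
    }
    // Mirrors the epilogue's guard: an empty or fully masked row has sum 0 (or NaN).
    return (l == 0.f || l != l) ? 0.f : acc / l;
}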
csrc/flash_attn/src/flash_fwd_launch_template.h (new file, mode 100644)

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_kernel.h"

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    // printf("smem_size = %d\n", smem_size);

    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21

    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.b, params.h);
    // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
    // for cu_seqlens_q as well.
    const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr
        && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    const bool return_softmax = params.p_ptr != nullptr;
    BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                // Will only return softmax if dropout, to reduce compilation time.
                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
                // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
                if (smem_size >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                }
                int ctas_per_sm;
                cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                    &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
                // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
            });
        });
    });
}

template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr int Headdim = 32;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr int Headdim = 64;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr (!Is_dropout) {
                // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
                // Using block size (64 x 256) is 27% slower for seqlen=2k
                // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
        });
    });
}

template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr int Headdim = 96;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
            if (is_sm8x) {
                if constexpr (!Is_causal) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            // These two are always slower
            // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr int Headdim = 128;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr (!Is_dropout) {
                // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
                // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
                if (is_sm8x) {
                    if constexpr (!Is_causal) {
                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                    } else {
                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                    }
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // 1st ones are good for H100, A100
                // 2nd one is good for A6000 bc we get slightly better occupancy
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            }
        });
    });
}

template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr int Headdim = 160;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For A100, H100, 128 x 32 is the fastest.
            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
            // and 128 x 64 with 8 warps is the fastest for non-causal.
            if (is_sm8x) {
                if constexpr (!Is_causal) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr int Headdim = 192;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr (!Is_dropout) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr int Headdim = 224;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
            // If we have N = 32, there are only 1024 elements to load at once, where each load
            // is 8 elements. This means we can only use 128 threads and not 256 threads.
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_sm, max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
    status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For A100, we want to run with 128 x 64 (128KB smem).
            // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // 64 KB
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // 96 KB
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}
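The head-dim 224 and 256 launchers above size their tiles against the device's opt-in shared-memory limit. The check 2 * Headdim * (128 + 2 * 64) reads as a 2-byte Q tile of kBlockM x Headdim plus K and V tiles of kBlockN x Headdim each: for Headdim = 224 that is 2 * 224 * 256 = 114,688 bytes, the "112 KB" in the comment, which exceeds the default 48 KB dynamic shared-memory limit and so requires the cudaFuncSetAttribute opt-in that run_flash_fwd performs before launching. A minimal standalone sketch of that opt-in, using hypothetical names (my_kernel, launch_with_big_smem) rather than code from this commit:

#include <cuda_runtime.h>

// Hypothetical kernel using dynamic shared memory sized at launch time.
__global__ void my_kernel(const float* in, float* out) {
    extern __shared__ char smem_[];   // dynamic shared memory buffer
    out[threadIdx.x] = in[threadIdx.x];
}

// Sketch: raise the per-kernel dynamic smem cap before launching with more than 48 KB,
// mirroring the cudaFuncSetAttribute call wrapped in C10_CUDA_CHECK inside run_flash_fwd.
void launch_with_big_smem(const float* in, float* out, cudaStream_t stream) {
    const size_t smem_size = 112 * 1024;   // e.g. the 112 KB budget computed above
    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    my_kernel<<<1, 128, smem_size, stream>>>(in, out);
}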
csrc/flash_attn/src/fmha.h
deleted
100644 → 0
View file @
6d48e14a
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/UnpackRaw.cuh>
#include <fmha_utils.h>
constexpr
int
TOTAL_DIM
=
0
;
constexpr
int
H_DIM
=
1
;
constexpr
int
D_DIM
=
2
;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
Qkv_params
{
// The QKV matrices.
void
*
__restrict__
q_ptr
;
void
*
__restrict__
k_ptr
;
void
*
__restrict__
v_ptr
;
// The stride between rows of the Q, K and V matrices.
// size_t qkv_stride_in_elts;
// size_t qkv_stride_in_bytes;
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
uint32_t
q_row_stride_in_elts
;
uint32_t
k_row_stride_in_elts
;
uint32_t
v_row_stride_in_elts
;
uint32_t
q_head_stride_in_elts
;
uint32_t
k_head_stride_in_elts
;
uint32_t
v_head_stride_in_elts
;
// The number of heads.
int
h
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct
FMHA_fprop_params
:
public
Qkv_params
{
// The O matrix (output).
void
*
__restrict__
o_ptr
;
// The stride between rows of O.
// size_t o_stride_in_elts;
// size_t o_stride_in_bytes;
uint32_t
o_row_stride_in_elts
;
uint32_t
o_head_stride_in_elts
;
uint32_t
o_tmp_row_stride_in_elts
;
uint32_t
o_tmp_head_stride_in_elts
;
// The pointer to the O_tmp matrix, which holds O intermediate value during
// the loop;
void
*
__restrict__
o_tmp_ptr
;
// The pointer to the S matrix.
void
*
__restrict__
s_ptr
;
// The stride between rows of the S matrix.
// int64_t s_stride_in_bytes;
uint32_t
s_stride_in_bytes
;
// The pointer to the softmax sum.
void
*
__restrict__
softmax_lse_ptr
;
// The dimensions.
int
b
,
seqlen_q
,
seqlen_k
,
d
;
// The scaling factors for the kernel.
float
scale_bmm1f
;
uint32_t
scale_bmm1
;
// array of length b+1 holding starting offset of each sequence.
int
*
__restrict__
cu_seqlens_q
;
int
*
__restrict__
cu_seqlens_k
;
int
*
__restrict__
blockmask
;
// The dropout probability (probability of keeping an activation).
float
p_dropout
;
uint32_t
p_dropout_in_uint
;
uint16_t
p_dropout_in_uint16_t
;
// Scale factor of 1 / (1 - p_dropout).
float
rp_dropout
;
float
scale_bmm1_rp_dropout
;
// Scale factor of 1 / (1 - p_dropout), in half2.
uint32_t
scale_dropout
;
// Random state.
at
::
PhiloxCudaState
philox_args
;
// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t
*
rng_state
;
bool
is_bf16
;
bool
is_causal
;
int
num_splits
;
// How many SMs per attention matrix.
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FMHA_dgrad_params : public FMHA_fprop_params {

    // The dQKV matrices.
    void *__restrict__ dq_ptr;
    void *__restrict__ dk_ptr;
    void *__restrict__ dv_ptr;

    // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q dimension
    // void *__restrict__ dk_accum_ptr;
    // void *__restrict__ dv_accum_ptr;

    // The stride between rows of the dQ, dK and dV matrices.
    // TD [2022-04-16]: We're using 32-bit indexing to save registers.
    // The code probably won't work for arrays larger than 2GB.
    uint32_t dq_row_stride_in_elts;
    uint32_t dk_row_stride_in_elts;
    uint32_t dv_row_stride_in_elts;
    uint32_t dq_head_stride_in_elts;
    uint32_t dk_head_stride_in_elts;
    uint32_t dv_head_stride_in_elts;

    // The dO matrix. We assume it is contiguous.
    void * __restrict__ do_ptr;

    // The pointer to the softmax d sum.
    void * __restrict__ dsoftmax_sum;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_params>
struct Launch_params {
    Launch_params(cudaDeviceProp * props_,
                  cudaStream_t stream_,
                  bool is_dropout_,
                  bool return_softmax_)
        : elts_per_thread(0)
        , props(props_)
        , stream(stream_)
        , is_dropout(is_dropout_)
        , return_softmax(return_softmax_) {
    }

    size_t elts_per_thread;

    cudaDeviceProp * props;

    cudaStream_t stream;

    bool is_dropout;
    bool return_softmax;

    Kernel_params params;
    int num_full_heads;
    int num_main_groups;
    int heads_last_wave;
    int main_steps;
    int rest_steps;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params);
void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params);
void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params);

void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);

void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);

void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
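For orientation, a minimal host-side sketch of how these forward entry points are meant to be driven. The wrapper name launch_fwd_example and the head-dimension dispatch thresholds are illustrative assumptions, not code from this commit; the caller is assumed to have filled in the pointer, stride, and shape fields of FMHA_fprop_params already.

// Illustrative sketch only (not part of this commit): populate Launch_params
// and dispatch to one of the forward launchers declared above.
void launch_fwd_example(const FMHA_fprop_params &p, cudaDeviceProp *props, cudaStream_t stream) {
    // p.p_dropout is the probability of *keeping* an activation, so dropout is
    // active whenever it is below 1 (assumption based on the comment above).
    const bool is_dropout = p.p_dropout < 1.f;
    Launch_params<FMHA_fprop_params> launch_params(props, stream, is_dropout,
                                                   /*return_softmax=*/false);
    launch_params.params = p;
    // Hypothetical dispatch on head dimension; the real callers pick the
    // specialization that matches the padded head dim.
    if (p.d <= 32) {
        run_fmha_fwd_hdim32(launch_params);
    } else if (p.d <= 64) {
        run_fmha_fwd_hdim64(launch_params);
    } else {
        run_fmha_fwd_hdim128(launch_params);
    }
}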
csrc/flash_attn/src/fmha/gemm.h
deleted
100644 → 0
View file @
6d48e14a
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha/utils.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/layout/layout.h"
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
struct Fragment_base_ {

    // The data type.
    using Data_type = Data_type_;
    // default input type
    using Input_type_ = Data_type_;
    // Does it store the array of elements.
    static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8;
    // The number of elements.
    static constexpr int NUM_ELTS = NUM_ELTS_;
    // The size of element in bits.
    static constexpr int BITS_PER_ELT = BITS_PER_ELT_;
    // The size of byte of a single register.
    static constexpr int BYTES_PER_REG = 4;
    // The size in bits.
    static constexpr int BITS_PER_REG = BYTES_PER_REG * 8;
    // The number of registers needed to store the fragment.
    static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG);
    // The size in bytes (as returned by sizeof(Fragment_base<>).
    static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG;
    // The alignment.
    static constexpr int ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
    // The type of the elements.
    typename Data_type_,
    // The number of elements.
    int NUM_ELTS_,
    // The alignment if you want to force a value -- use 0 otherwise.
    int ALIGNMENT_ = 0,
    // The base class.
    typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {

    // The size of a load/store.
    static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t);

    // Clear the fragment. Using PTX in that code seems to produce better SASS...
    inline __device__ void clear() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) :);
        }
    }

    // Immutable access to a register.
    inline __device__ const uint32_t& reg(int ii) const {
        return this->regs_[ii];
    }

    // Mutable access to a register.
    inline __device__ uint32_t& reg(int ii) {
        return this->regs_[ii];
    }

    uint32_t regs_[Base_::NUM_REGS];

    // Immutable access to the elements.
    inline __device__ const Data_type_& elt(int ii) const {
        return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements.
    inline __device__ Data_type_& elt(int ii) {
        return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
    }

    // Immutable access to the elements with a cast.
    template< typename Cast_type >
    inline __device__ const Cast_type& elt_as(int ii) const {
        return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements.
    template< typename Cast_type >
    inline __device__ Cast_type& elt_as(int ii) {
        return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
    }

    // Add another fragment.
    inline __device__ void add(const Fragment &other) {
        // TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS?
        // Also are we doing int addition or __half2 addition?
        #pragma unroll
        for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
            this->elt(ii) += other.elt(ii);
        }
    }

    // Multiply by another fragment.
    inline __device__ void hmul(const Fragment &other) {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
        }
    }

    template <typename elem_type>
    inline __device__ void hrelu_() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            this->reg(ii) = fmha::hrelu2<elem_type>(this->reg(ii));
        }
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Layout >
struct Fragment_a : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Layout >
struct Fragment_b : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Fragment_accumulator : public Fragment<float, 8> {

    // The base class.
    using Base = Fragment<float, 8>;

    // Add two fragments.
    template< typename Other_fragment_ >
    inline __device__ void add(const Other_fragment_ &other) {
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) = this->elt(ii) + other.elt(ii);
        }
    }

    inline __device__ void mul_(const float other) {
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) *= other;
        }
    }

    // Do the HMMA.
    template< typename Layout_a, typename Layout_b >
    inline __device__ void mma(const Fragment_a<Layout_a> &a, const Fragment_b<Layout_b> &b) {
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
                    : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3))
                    :  "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
                    ,  "r"(b.reg(0)), "r"(b.reg(1)));
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
                    : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7))
                    :  "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
                    ,  "r"(b.reg(2)), "r"(b.reg(3)));
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Fragment, int M, int N >
inline __device__ void clear(Fragment (&frag)[M][N]) {
    #pragma unroll
    for( int mi = 0; mi < M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            frag[mi][ni].clear();
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Accumulator_type, int WARPS_K >
struct Clear_accumulator {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int WARPS_K >
struct Clear_accumulator<float, WARPS_K> {
    template< typename Acc, int M, int N >
    static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
        fmha::clear(acc);
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    #pragma unroll
    for( int mi = 0; mi < M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            acc[mi][ni].mma(a[mi], b[ni]);
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps half types => cutlass data types
/////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Type_>
struct HalfTypeToCutlassType { using Type = Type_; };
/// Statically maps __half => cutlass::half_t
template<> struct HalfTypeToCutlassType<__half> { using Type = cutlass::half_t; };
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
template<> struct HalfTypeToCutlassType<__nv_bfloat16> { using Type = cutlass::bfloat16_t; };
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename elem_type, typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
#else
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
    // TD [2022-06-02] We don't support Volta (SM70) yet.
    assert(0);
#endif
    using Element = typename HalfTypeToCutlassType<elem_type>::Type;
    using ElementC = float;
    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::ColumnMajor;
    using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
        cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;

    constexpr int kIters = Shape::kK / InstructionShape::kK;

    // using FragmentA = typename WarpMma::FragmentA;
    // using FragmentB = typename WarpMma::FragmentB;
    using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA;
    using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB;
    using FragmentC = typename WarpMma::FragmentC;
    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) {
    //     printf("FragmentA::kStorageElements = %d\n", FragmentA::kStorageElements);
    //     printf("Archmma::FragmentA::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentA::kStorageElements);
    //     printf("FragmentB::kStorageElements = %d\n", FragmentB::kStorageElements);
    //     printf("Archmma::FragmentB::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentB::kStorageElements);
    //     printf("FragmentC::kStorageElements = %d\n", FragmentC::kStorageElements);
    //     printf("Archmma::FragmentC::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentC::kStorageElements);
    // }
    // static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
    // static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
    static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS);
    static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS);
    static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
    // const FragmentA a_cl = reinterpret_cast<const FragmentA (&)>(a);
    // const FragmentB b_cl = reinterpret_cast<const FragmentB (&)>(b);
    FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);
    FragmentA a_cl[kIters][M];
    FragmentA b_cl[kIters][N];
    constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2;
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        #pragma unroll
        for (int mi = 0; mi < M; mi++) {
            uint32_t *a_ptr = a_cl[iter][mi].raw_data();
            #pragma unroll
            for (int ki = 0; ki < kRegs; ki++) {
                a_ptr[ki] = a[mi].regs_[iter * kRegs + ki];
            }
        }
    }
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        #pragma unroll
        for (int ni = 0; ni < N; ni++) {
            uint32_t *b_ptr = b_cl[iter][ni].raw_data();
            #pragma unroll
            for (int ki = 0; ki < kRegs; ki++) {
                // b_ptr[ki] = b[ni].regs_[iter * kRegs + ki];
                // TD [2022-06-02] For some reason the order for frag_b is different.
                b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter];
            }
        }
    }

    WarpMma mma_op;
    // mma_op(c_cl, a_cl, b_cl, c_cl);
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        mma_op(c_cl,
               reinterpret_cast<const typename WarpMma::FragmentA (&)>(a_cl[iter]),
               reinterpret_cast<const typename WarpMma::FragmentB (&)>(b_cl[iter]),
               c_cl);
    }

    // The modified c_cl is not copied back into acc, idk why
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        #pragma unroll
        for (int ni = 0; ni < N; ni++) {
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i];
            }
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
    // The number of rows in the CTA tile.
    int M_,
    // The number of cols in the CTA tile.
    int N_,
    // The number of elements in the the K dimension of the GEMM loop.
    int K_,
    // The number of rows of warps.
    int WARPS_M_,
    // The number of cols of warps.
    int WARPS_N_,
    // The number of warps in the K dimension of the GEMM loop.
    int WARPS_K_
>
struct Cta_tile_ {

    static constexpr int M = M_, N = N_, K = K_;
    // The number of warps.
    static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
    // The number of warps per CTA.
    static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
    // The number of threads per warp.
    static constexpr int THREADS_PER_WARP = 32;
    // The number of threads per CTA.
    static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Hmma_tile {
    // The number of elements computed with a single warp-MMA.
    static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;

    // The number of elements computed with a single CTA-MMA.
    static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
                         N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
                         K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;

    // The number of MMAs needed to compute the GEMM.
    static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
                         MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
                         MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);

    // // The number of elements computed per warp.
    // static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA,
    //                      N_PER_WARP = MMAS_N * N_PER_MMA,
    //                      K_PER_WARP = MMAS_K * K_PER_MMA;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;

constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile_>
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
                                                   Cta_tile_::N,
                                                   Next_power_of_two<Cta_tile_::K>::VALUE,
                                                   Cta_tile_::WARPS_M,
                                                   Cta_tile_::WARPS_N,
                                                   Cta_tile_::WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
}  // namespace fmha
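As a concrete reading of the Cta_tile_ / Hmma_tile arithmetic in this file, here is a hypothetical instantiation; the tile and warp sizes are assumptions chosen for illustration, not values taken from this commit.

// Sketch: tile arithmetic for a 16x128x64 CTA tile with a 1x4x1 warp layout
// (assumed values). Each warp-MMA covers 16x16x16, so the CTA-wide MMA shape
// is 16x64x16 and the GEMM needs 1 x 2 x 4 MMAs in M/N/K.
using Example_tile = fmha::Cta_tile_extd<16, 128, 64, 1, 4, 1>;
using Example_mma  = fmha::Hmma_tile<Example_tile>;
static_assert(Example_mma::M_PER_MMA_PER_CTA == 16);
static_assert(Example_mma::N_PER_MMA_PER_CTA == 64);
static_assert(Example_mma::MMAS_M == 1 && Example_mma::MMAS_N == 2 && Example_mma::MMAS_K == 4);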
csrc/flash_attn/src/fmha/gmem_tile.h
deleted
100644 → 0
View file @
6d48e14a
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <fmha/utils.h>
namespace fmha {
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile_,
    // The number of bits per element.
    int BITS_PER_ELEMENT,
    // The number of rows of Q, K or V loaded by this tile.
    int ROWS_,
    // The number of columns.
    int COLS,
    int BYTES_PER_LDGS_ = 16
>
struct Gmem_tile_qkv {

    using Cta_tile = Cta_tile_;

    static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8;
    // The size of each LDG.
    static constexpr int BYTES_PER_LDG = BYTES_PER_LDGS_;
    // The size of a row in bytes.
    static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8;

    // The number of threads to load a "row" of the matrix.
    static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG;

    static constexpr int ROWS = ROWS_;
    // The number of "rows" loaded per LDG.
    static constexpr int ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
    // The number of LDGs needed to load a chunk of the Q matrix.
    static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG);

    // Ctor.
    template< typename BInfo >
    inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
                                    const uint32_t head_stride_in_elts, const int headdim,
                                    const BInfo &binfo, const int tidx, bool use_seqlen_q)
        : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
        , actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k)
        , ptr(reinterpret_cast<char *>(ptr_))
        , tidx_(tidx)
        , col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_LDG / BYTES_PER_ELEMENT) < headdim) {

        // Compute the position in the sequence (within the CTA for the moment).
        int row = tidx / THREADS_PER_ROW;
        // Compute the position of the thread in the row.
        int col = tidx % THREADS_PER_ROW;

        // Store the row as we need it to disable the loads.
        // TD [2022-04-16]: To minimize registers, we'll recompute row_ instead of storing it
        // row_ = row;

        // The row offset in the batched GEMM. For each seq element, we store QKV in that order.
        // int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
        uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes);
        // Add the block index.
        // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
        row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);

        // Assemble the final pointer.
        ptr += row_offset + col * BYTES_PER_LDG;
    }

    // Store data to shared memory.
    template< typename Smem_tile >
    inline __device__ void commit(Smem_tile &smem_tile) {
        smem_tile.store(fetch_);
    }

    inline __device__ void load() {
        int row_ = tidx_ / THREADS_PER_ROW;
        const void *ptrs[LDGS];
        uint32_t preds[LDGS];
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            // ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
            ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
            preds[ii] = col_predicate && ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
            fetch_[ii] = make_uint4(0, 0, 0, 0);
        }

        // not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
        Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            fct.load(ii, preds[ii]);
        }
    }

    // Store data to memory.
    inline __device__ void store(const uint4 (&data)[LDGS]) {
        int row_ = tidx_ / THREADS_PER_ROW;
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
            char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
            if( col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
                fmha::stg(ptr_, data[ii]);
            }
        }
    }

    inline __device__ void move(const int steps = 1) {
        // ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
        ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
        actual_seqlen -= ROWS * steps;
    }

    // The stride between rows for the QKV matrice.
    // int64_t row_stride_in_bytes;
    const uint32_t row_stride_in_bytes;
    // The pointer.
    char *ptr;
    // The fetch registers.
    uint4 fetch_[LDGS];
    // Keep track of the row the thread is processing as we move the tile.
    // int row_;
    const int tidx_;
    // The length of the sequence loaded by that memory tile.
    int actual_seqlen;
    const bool col_predicate;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, int BYTES_PER_ELEMENT = 2 >
struct Gmem_tile_o {

    static_assert(BYTES_PER_ELEMENT == 2 || BYTES_PER_ELEMENT == 4);

    // The mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The size of each element.
    // static constexpr int BYTES_PER_ELEMENT = 2;
    // The size of each STG.
    static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 4;
    static constexpr int COLS = Cta_tile::N;
    // The size of a row in bytes.
    static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT;

    // The number of threads to store a "row" of the matrix.
    static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG;
    // The number of "rows" stored per iteration of the loop. The output of 1 MMA.
    static constexpr int ROWS = Cta_tile::M;
    // The number of "rows" stored per iteration of the loop. The output of 1 MMA.
    static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA;
    // The number of outter loop for the stores.
    static constexpr int LOOPS = ROWS / ROWS_PER_LOOP;

    // The number of "rows" stored per STG.
    static constexpr int ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
    // Do we have to guard against partial writes/reads.
    static constexpr bool HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0;
    // The number of STGs needed to store a chunk of the Q matrix.
    static constexpr int STGS_PER_LOOP = DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_STG);
    // The number of STGs needed to store a chunk of the Q matrix in total.
    static constexpr int STGS = STGS_PER_LOOP * LOOPS;

    // Ctor.
    template<typename BInfo>
    // inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx)
    inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
                                  const uint32_t head_stride_in_elts, const int headdim,
                                  const BInfo &binfo, const int tidx)
        : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
        , actual_seqlen_q(binfo.actual_seqlen_q)
        , ptr_(reinterpret_cast<char *>(ptr))
        , tidx_(tidx)
        , col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_STG / BYTES_PER_ELEMENT) < headdim) {

        // Compute the position in the sequence (within the CTA for the moment).
        int row = tidx / THREADS_PER_ROW;
        // Compute the position of the thread in the row.
        int col = tidx % THREADS_PER_ROW;

        // Store the row as we need it to disable loads.
        // row_ = row;

        // The row offset in the batched GEMM.
        // int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
        uint32_t row_offset = (uint32_t)((binfo.sum_s_q + row) * row_stride_in_bytes);
        row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
        // Assemble the final pointer.
        ptr_ += row_offset + col * BYTES_PER_STG;

        // Is that thread active on the last STG?
        if( HAS_INCOMPLETE_STG ) {
            is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
        }
    }

    // Store data to global memory.
    template<typename elem_type=__half>
    inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
        int row_ = tidx_ / THREADS_PER_ROW;
        #pragma unroll
        for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
            int jj = mi * STGS_PER_LOOP + ii;
            if( (!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q) ) {
                break;
            }

            if (BYTES_PER_ELEMENT == 4) {
                if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
                    fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, src[ii]);
                }
            } else if (BYTES_PER_ELEMENT == 2) {
                float x = reinterpret_cast<const float &>(src[ii].x);
                float y = reinterpret_cast<const float &>(src[ii].y);
                float z = reinterpret_cast<const float &>(src[ii].z);
                float w = reinterpret_cast<const float &>(src[ii].w);
                uint2 out = fmha::float4_pack<elem_type>(x, y, z, w);
                if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
                    fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out);
                }
            }
        }
    }

    // Store data to global memory with atomicAdd.
    inline __device__ void atomic_add(const uint4 (&src)[STGS_PER_LOOP], int mi) {
        static_assert(BYTES_PER_ELEMENT == 4);  // Only do atomic add on floats
        int row_ = tidx_ / THREADS_PER_ROW;
        #pragma unroll
        for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
            int jj = mi * STGS_PER_LOOP + ii;
            if( (!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q) ) {
                break;
            }

            if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
                float *ptr_ = reinterpret_cast<float *>(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
                #pragma unroll
                for (int jj = 0; jj < 4; ++jj) {
                    atomicAdd(ptr_ + jj, reinterpret_cast<const float (&)[4]>(src[ii])[jj]);
                }
            }
        }
    }

    // Load data from global memory.
    inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
        static_assert(BYTES_PER_ELEMENT == 4);
        int row_ = tidx_ / THREADS_PER_ROW;
        #pragma unroll
        for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
            int jj = mi * STGS_PER_LOOP + ii;
            if( (!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q) ) {
                break;
            }

            if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
                fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
            }
        }
    }

    inline __device__ void move(const int steps = 1) {
        // row_ += ROWS * steps;
        // ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps;
        ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps;
        actual_seqlen_q -= ROWS * steps;
    }

    // The stride between rows for the QKV matrice.
    // int64_t row_stride_in_bytes;
    const uint32_t row_stride_in_bytes;
    // The pointer.
    char *ptr_;
    // Is the thread active for the last STG?
    int is_active_for_last_stg_;
    // The length of the sequence loaded by that memory tile.
    int actual_seqlen_q;
    const int tidx_;
    const bool col_predicate;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, int BYTES_PER_ELEMENT >
struct Gmem_tile_mma_sd {

    // The mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // Each STG stores 8 elements.
    static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8;
    // The number of MMAs in the M dimension.
    static constexpr int MMAS_M = Mma_tile::MMAS_M;
    // The number of MMAs in the N dimension.
    static constexpr int MMAS_N = Mma_tile::MMAS_N;
    // The number of rows computed per MMA per thread block.
    static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
    // The number of cols computed per MMA per thread block.
    static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA;
    // The number of threads per block.
    static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA;
    // The size of each row in bytes. I.e. how many bytes are stored per STG.
    static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG;
    // The distance between elements stored per loop (in bytes).
    static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW;

    // The type of elements stored per STG.
    using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;

    // Ctor.
    template<typename Params>
    inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx)
        : ptr_(static_cast<char *>(ptr)) {

        // The block index.
        // size_t bidx = bidb * params.h + bidh;
        uint32_t bidx = bidb * params.h + bidh;

        // The distance between two blocks (in bytes).
        // const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
        const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
        // Set store location for each thread at the beginning of the loop
        ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG;
    }

    // Store to global memory.
    inline __device__ void store(const Type &data, const int mi, const int ni) {
        // size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
        uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
        fmha::stg(ptr_ + offset, data);
    }

    // Load from global memory.
    inline __device__ void load(Type &data, const int mi, const int ni) {
        // size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
        uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
        fmha::ldg(data, ptr_ + offset);
    }

    // Move to the next tile.
    inline __device__ void move(const int steps = 1) {
        ptr_ += LOOP_STRIDE_BYTES * steps;
    }

    // The pointer in global memory.
    char *ptr_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
struct Gmem_tile_mma_s : public Base {

    // The number of mmas in the vertical dimension.
    static constexpr int M = Base::MMAS_M;
    // The number of mmas in the horizontal dimension.
    static constexpr int N = Base::MMAS_N;
    // The type of the vectors stored by each STG.
    using Type = typename Base::Type;

    // Ctor.
    template< typename Params, typename Block_info >
    inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info &binfo, const int tidx)
        : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
    }

    // Store to global memory.
    template<typename Mask, typename Fragment>
    inline __device__ void store(const Fragment (&frag)[N][M], const Mask &mask){
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                uint4 dst;
                dst.x = frag[ni][mi].reg(0);
                dst.y = frag[ni][mi].reg(2);
                dst.z = frag[ni][mi].reg(1);
                dst.w = frag[ni][mi].reg(3);
                if( mask.any_valid(mi, ni) ) {
                    Base::store(dst, mi, ni);
                }
            }
        }
    }

    // Load from global memory.
    template<typename Mask>
    inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                regs[mi][ni] = make_uint4(0, 0, 0, 0);
                if( mask.any_valid(mi, ni) ) {
                    Base::load(regs[mi][ni], mi, ni);
                }
            }
        }
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile
>
struct Gmem_summary_stats {

    // The Mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The number of MMAs in M/N dimensions.
    static constexpr int MMAS_M = Mma_tile::MMAS_M;

    // The size of each element.
    static constexpr int BYTES_PER_ELEMENT = 4;
    static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT;
    static constexpr int ROWS = Cta_tile::M;

    // Ctor.
    template<typename Params>
    inline __device__ Gmem_summary_stats(void *ptr, const Params &params, const int tidx)
        : ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {

        // The block index for the batch.
        const int bidb = blockIdx.x;
        // The block index for the head.
        const int bidh = blockIdx.y;
        // The block index.
        // size_t bidx = bidb * params.h + bidh;
        uint32_t bidx = bidb * params.h + bidh;

        // Extract the position in the warp.
        int warp = tidx / Cta_tile::THREADS_PER_WARP;
        int lane = tidx % Cta_tile::THREADS_PER_WARP;

        // The distance between two blocks (in bytes).
        // size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
        uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;

        // Set store location for each thread at the beginning of the loop
        ptr_row_ = ptr_ + bidx * block_stride_bytes;
        ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT;
    }

    // Store data to global memory.
    inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) {
        int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
        int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
        if( (warp == 0) && (lane % 4 == 0) ) {
            #pragma unroll
            for( int mi = 0; mi < MMAS_M; ++mi ) {
                // TODO: Not sure if it's right for MMAS_M > 1
                fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]);
                fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]);
            }
        }
    }

    // Store data to global memory.
    inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) {
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            // TODO: Not sure if it's right for MMAS_M > 1
            fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]);
        }
    }

    // Load from global memory.
    inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) {
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            // TODO: Not sure if it's right for MMAS_M > 1
            fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
            fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
        }
    }

    // Load from global memory.
    inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps = 1) {
        char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT;
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            // TODO: Not sure if it's right for MMAS_M > 1
            fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
            fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
        }
    }

    // Store data to global memory.
    template <int N>
    inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT);
        }
    }

    // Move the pointer to the next location.
    inline __device__ void move() {
        ptr_ += ROWS * BYTES_PER_ELEMENT;
        ptr_row_ += ROWS * BYTES_PER_ELEMENT;
    }

    // Move the pointer to the next location.
    inline __device__ void move(const int steps) {
        ptr_ += ROWS * BYTES_PER_ELEMENT * steps;
        ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps;
    }

    // The pointer.
    char *ptr_;
    char *ptr_row_;
    const int tidx_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
}  // namespace fmha
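To make the load-shape arithmetic in Gmem_tile_qkv concrete, here is one hypothetical instantiation; the CTA size, row count, and column count are assumptions chosen for illustration, not taken from this commit.

// Sketch: load geometry for a 128-thread CTA loading 16 rows of a 64-column
// fp16 matrix with 16-byte LDGs (values assumed for illustration).
using Example_cta  = fmha::Cta_tile_extd<16, 128, 64, 1, 4, 1>;   // 128 threads
using Example_gmem = fmha::Gmem_tile_qkv<Example_cta, /*BITS_PER_ELEMENT=*/16, /*ROWS=*/16, /*COLS=*/64>;
static_assert(Example_gmem::BYTES_PER_ROW == 128);   // 64 cols * 2 bytes
static_assert(Example_gmem::THREADS_PER_ROW == 8);   // 128 B / 16 B per LDG
static_assert(Example_gmem::ROWS_PER_LDG == 16);     // 128 threads / 8 per row
static_assert(Example_gmem::LDGS == 1);              // ceil(16 / 16)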
csrc/flash_attn/src/fmha/kernel_traits.h
deleted
100644 → 0
View file @
6d48e14a
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <cuda_fp16.h>
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u, typename elem_type_=__half>
struct FMHA_kernel_traits {

    // The CTA description for the 1st GEMM.
    using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;
    // The CTA description for the 2nd GEMM.
    using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;

    // Do we use one buffer for K and V.
    static constexpr bool SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u;
    // Do we keep K in registers.
    static constexpr bool K_IN_REGS = (FLAGS & 0x10u) == 0u;
    // Do we keep V in registers.
    static constexpr bool V_IN_REGS = (FLAGS & 0x100u) == 0u;

    // The global memory tile to load Q.
    using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;

    // The shared memory tile to swizzle Q.
    // using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
    using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;

    // The global memory tile to load K.
    using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;

    // The global memory tile to load V.
    using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;

    // The global memory tile to store O.
    using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;
    // The shared memory tile for O.
    using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;;

    // The global memory tile to load/store S.
    using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;

    // The shared memory tile to transpose S.
    using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;

    using Gmem_tile_do = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;

    // // The global memory tile to store the accumulated dK and dV
    // // Hack: we set BYTES_PER_LDGS=32 to emulate the access pattern of dK and dV
    // // where there are 16 bits per lements and 16 bytes per load. In reality we won't
    // // be issue any load or store of size 32 bytes.
    // using Gmem_tile_dkv_accum = fmha::Gmem_tile_qkv<Cta_tile_o, 32, S, D, 32>;

    // The global memory tile to store the softmax sum.
    using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;

    // The shared memory tile to store dp sum.
    using Smem_dp_sum = fmha::Smem_tile_dp_sum<Gmem_tile_q, 2>;

    using elem_type = elem_type_;

    // Make sure the number of threads match.
    static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");

    // The number of threads.
    static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA;
    // Make sure the number of threads matches both CTAs.
    static_assert(THREADS == Cta_tile_o::THREADS_PER_CTA, "");

    // The amount of shared memory needed to load Q and K.
    static constexpr int BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE;
    // The extra amount of shared memory needed to load V.
    static constexpr int BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE;
    // The amount of shared memory needed for Q, K and V..
    static constexpr int BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V;
    // The amount of shared memory needed to load Q and store O.
    static constexpr int BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE;

    // The amount of shared memory needed for Q, K, V and O.
    static constexpr int BYTES_PER_SMEM = fmha::MaxConstexpr(BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO);
    // Make sure we have enough shared memory.
    static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};
////////////////////////////////////////////////////////////////////////////////////////////////////
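For reference, a hypothetical instantiation of these traits; the sequence length, head dimension, step, warp layout, and FLAGS value are illustrative assumptions rather than parameters taken from the launchers in this commit, and the assertions assume the fmha gemm/gmem/smem tile headers are all in scope.

// Sketch (assumed values): seqlen 128, head dim 64, 16-row step, 1x4 warps,
// default FLAGS = 0x08u, i.e. K and V share one shared-memory buffer.
using Example_kernel_traits = FMHA_kernel_traits</*S=*/128, /*D=*/64, /*STEP=*/16,
                                                 /*WARPS_M=*/1, /*WARPS_N=*/4, /*FLAGS=*/0x08u, __half>;
static_assert(Example_kernel_traits::THREADS == 128);           // 1 * 4 * 1 warps of 32 threads
static_assert(Example_kernel_traits::SHARE_SMEM_FOR_K_AND_V);   // bit 0x08 is set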
csrc/flash_attn/src/fmha/mask.h
deleted
100644 → 0
View file @
6d48e14a
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {
template<typename Cta_tile, bool Is_causal=false>
struct Mask {
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    template<typename BInfo>
    __device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0)
        : actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N)
        , loop_step_idx(loop_step_idx_) {

        const int warp = tidx / Cta_tile::THREADS_PER_WARP;
        const int lane = tidx % Cta_tile::THREADS_PER_WARP;

        static_assert(Cta_tile::WARPS_K == 1, "");

        // find the warp in the Cta tile
        const int warp_n = (warp / Cta_tile::WARPS_M);
        const int warp_m = (warp % Cta_tile::WARPS_M);
        // decompose warp into 8x4 tile
        const int quad = lane / 4;
        const int tid = (lane % 4) * 2;
        row = warp_m * 16 + quad;
        col = warp_n * 16 + tid;
    }

    inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {

        // ii and jj iterate over the 2x4 fragment
        // const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
        const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
        const int current_row = row_offset + ii * 8;
        const bool col_valid = current_col < actual_seqlen_k;
        // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k;
            //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k;
        // bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 1)) {
        //     printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
        // }
        return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
        // return row_valid && col_valid;
    }

    //BERT Mask: if upper left is invalid, none are valid
    inline __device__ bool any_valid(const int mi, const int ni) const {
        return is_valid(mi, ni, 0, 0) || is_valid(mi, ni, 1, 0);
    }

    inline __device__ void load(const int it) {
        row_offset = it * Cta_tile::M + row;
    }

    int row_offset;

    int row;
    int col;
    const int loop_step_idx;
    const int actual_seqlen_k;
};
}  // namespace fmha
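The causal branch of Mask::is_valid reduces to a simple coordinate comparison once the fragment indexing is stripped away. The standalone helper below restates that rule; the function name and free-function form are illustrative, not part of this commit.

// Sketch of the masking rule implemented by Mask::is_valid above: a key column
// is visible when it lies inside the (possibly shortened) K sequence and, for
// causal attention, when its absolute column index does not exceed the query row.
inline bool is_visible(int current_row, int current_col, int loop_step_idx,
                       int cta_tile_n, int actual_seqlen_k, bool is_causal) {
    const bool col_valid = current_col < actual_seqlen_k;
    return is_causal ? col_valid && (current_col + loop_step_idx * cta_tile_n <= current_row)
                     : col_valid;
}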
csrc/flash_attn/src/fmha/smem_tile.h
deleted
100644 → 0
View file @
6d48e14a
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "utils.h"
#include <fmha/utils.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The description of the tile computed by this CTA.
typename
Cta_tile
,
// The number of rows in the 2D shared memory buffer.
int
M_
,
// The number of cols.
int
N_
,
// The size in bits of each element.
int
BITS_PER_ELEMENT_
,
// The number of bytes per STS.
int
BYTES_PER_STS_
=
16
,
// The number of buffers. (Used in multistage and double buffer cases.)
int
BUFFERS_PER_TILE_
=
1
,
// Do we enable the fast path for LDS.128 and friends.
int
ENABLE_LDS_FAST_PATH_
=
0
,
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
int
ROWS_PER_XOR_PATTERN_
=
8
,
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
int
COLS_PER_XOR_PATTERN_
=
1
,
// Use or not predicates
bool
USE_PREDICATES_
=
true
>
struct
Smem_tile_without_skews
{
// The size in bits of each element.
enum
{
BITS_PER_ELEMENT
=
BITS_PER_ELEMENT_
};
// The size in bytes of a single STS.
enum
{
BYTES_PER_STS
=
BYTES_PER_STS_
};
// The number of elements per STS.
enum
{
ELEMENTS_PER_STS
=
BYTES_PER_STS
*
8
/
BITS_PER_ELEMENT
};
// To support arbitrary N, we pad some values to a power-of-2.
enum
{
N_WITH_PADDING
=
Next_power_of_two
<
N_
>::
VALUE
};
// The number of bytes per row without packing of rows.
enum
{
BYTES_PER_ROW_BEFORE_PACKING
=
N_WITH_PADDING
*
BITS_PER_ELEMENT
/
8
};
// The number of bytes per row -- we want at least 128B per row.
enum
{
BYTES_PER_ROW
=
Max
<
BYTES_PER_ROW_BEFORE_PACKING
,
128
>::
VALUE
};
// The number of rows in shared memory (two rows may be packed into a single one).
enum
{
ROWS
=
M_
*
BYTES_PER_ROW_BEFORE_PACKING
/
BYTES_PER_ROW
};
// The number of threads per row.
enum
{
THREADS_PER_ROW_UNBOUNDED
=
BYTES_PER_ROW
/
BYTES_PER_STS
};
// The number of threads per row.
enum
{
THREADS_PER_ROW
=
Min
<
Cta_tile
::
THREADS_PER_CTA
,
THREADS_PER_ROW_UNBOUNDED
>::
VALUE
};
// The number of STS per row.
enum
{
STS_PER_ROW
=
BYTES_PER_ROW
/
THREADS_PER_ROW
/
BYTES_PER_STS
};
// It must be at least one.
static_assert
(
STS_PER_ROW
>=
1
,
""
);
// The number of rows written with a single STS.
enum
{
ROWS_PER_STS
=
Cta_tile
::
THREADS_PER_CTA
/
THREADS_PER_ROW
};
// Make sure we write to at least one row per STS. Thanks Dr. Obvious ;)
static_assert
(
ROWS_PER_STS
>=
1
,
""
);
// The number of STS needed to store all rows.
enum
{
STS_PER_COL
=
Div_up
<
ROWS
,
ROWS_PER_STS
>::
VALUE
};
// The number of STS in total.
enum
{
STS
=
STS_PER_COL
*
STS_PER_ROW
};
// TD [2022-06-02] In the case of Q (16 x 64) in the backward pass with 256 threads,
// we only need to store 16 * 64 * 2 = 2KB instead of 4KB.
static
constexpr
bool
PARTIAL_STORE
=
ROWS_PER_STS
>
ROWS
;
static
constexpr
int
STORING_THREADS
=
PARTIAL_STORE
?
ROWS
*
THREADS_PER_ROW
:
Cta_tile
::
THREADS_PER_CTA
;
// The size of one buffer in bytes in shared memory.
// enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };
enum
{
BYTES_PER_BUFFER
=
STS
*
BYTES_PER_STS
*
STORING_THREADS
};
// The number of buffers.
enum
{
BUFFERS_PER_TILE
=
BUFFERS_PER_TILE_
};
// The size in bytes of total buffers.
enum
{
BYTES_PER_TILE
=
BYTES_PER_BUFFER
*
BUFFERS_PER_TILE
};
// The boundary for smem_read_offset and smem_write_offset increment.
enum
{
BYTES_PER_TILE_INC_BOUNDARY
=
BYTES_PER_TILE
-
BYTES_PER_BUFFER
};
// Do we enable the LDS.128 fast path?
enum
{
ENABLE_LDS_FAST_PATH
=
ENABLE_LDS_FAST_PATH_
};
static_assert
(
ENABLE_LDS_FAST_PATH
==
0
);
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
enum
{
ROWS_PER_XOR_PATTERN
=
ROWS_PER_XOR_PATTERN_
};
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
enum
{
COLS_PER_XOR_PATTERN
=
COLS_PER_XOR_PATTERN_
*
16
/
BYTES_PER_STS
};
// Use or not predicates
enum
{
USE_PREDICATES
=
USE_PREDICATES_
};
// The type of elements that are stored in shared memory by each thread.
using
Store_type
=
typename
Uint_from_size_in_bytes
<
BYTES_PER_STS
>::
Type
;
// Ctor.
inline
__device__
Smem_tile_without_skews
(
void
*
smem
,
int
tidx
)
:
smem_
(
__nvvm_get_smem_pointer
(
smem
)),
tidx_
(
tidx
)
{
// The row written by a thread. See doc/mma_smem_layout.xlsx.
int
smem_write_row
=
tidx
/
THREADS_PER_ROW
;
// The XOR pattern.
int
smem_write_xor
=
smem_write_row
%
ROWS_PER_XOR_PATTERN
*
COLS_PER_XOR_PATTERN
;
// Compute the column and apply the XOR pattern.
int
smem_write_col
=
(
tidx
%
THREADS_PER_ROW
)
^
smem_write_xor
;
// The offset.
this
->
smem_write_offset_
=
smem_write_row
*
BYTES_PER_ROW
+
smem_write_col
*
BYTES_PER_STS
;
// TODO: Why not merge it with the read offset?
// this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
// this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
}
    // Compute the store pointers.
    template< int N >
    inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {
        #pragma unroll
        for( int ii = 0; ii < N; ++ii ) {
            // Decompose the STS into row/col.
            int row = ii / STS_PER_ROW;
            int col = ii % STS_PER_ROW;

            // Assemble the offset.
            int offset = smem_write_offset_ + row * ROWS_PER_STS * BYTES_PER_ROW;

            // Take the column into account.
            if( STS_PER_ROW > 1 ) {
                offset += col * THREADS_PER_ROW * BYTES_PER_STS;
            }

            // Apply the XOR pattern if needed.
            if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {
                const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
                offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
            }

            // Assemble the final pointer :)
            // ptrs[ii] = smem_ + offset + smem_write_buffer_;
            // smem_write_buffer_ is already merged with smem_write_offset_
            ptrs[ii] = smem_ + offset;
        }
    }
    inline __device__ void debug_reset() {
        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER ) {
            for( int row = 0; row < ROWS; ++row ) {
                for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
                    if( threadIdx.x == 0 ) {
                        uint32_t val = 0x0;
                        sts(val, smem_ + row * BYTES_PER_ROW + col + buffer);
                    }
                }
            }
        }
    }
    // Print the content of the tile (only for debug ;)).
    inline __device__ void debug_print() const {
        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER ) {
            for( int row = 0; row < ROWS; ++row ) {
                for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
                    if( threadIdx.x == 0 ) {
                        uint32_t val;
                        lds(val, smem_ + row * BYTES_PER_ROW + col + buffer);
                        printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n",
                               blockIdx.x,
                               blockIdx.y,
                               blockIdx.z,
                               smem_,
                               buffer,
                               row,
                               col,
                               val);
                    }
                }
            }
        }
    }
    // Move the read offset to next buffer.
    inline __device__ void move_to_next_read_buffer() {
        // if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
        //     this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
        // } else if( BUFFERS_PER_TILE > 1 ) {
        //     this->smem_read_buffer_ += BYTES_PER_BUFFER;
        // }
        if( BUFFERS_PER_TILE > 1 && smem_read_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
            this->smem_read_offset_ -= BYTES_PER_TILE_INC_BOUNDARY;
        } else if( BUFFERS_PER_TILE > 1 ) {
            this->smem_read_offset_ += BYTES_PER_BUFFER;
        }
    }

    // Move the read offset to next buffer. TODO: Remove this member function!!!
    inline __device__ void move_next_read_buffer() {
        this->move_to_next_read_buffer();
    }

    // Move the read offset to next N buffer (circular-buffer).
    inline __device__ void move_to_next_read_buffer(int N) {
        if( BUFFERS_PER_TILE > 1 ) {
            // this->smem_read_buffer_ += N * BYTES_PER_BUFFER;
            // this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
            this->smem_read_offset_ += N * BYTES_PER_BUFFER;
            this->smem_read_offset_ -= smem_read_offset_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
        }
    }

    // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!
    inline __device__ void move_next_read_buffer(int N) {
        this->move_to_next_read_buffer(N);
    }

    // Move the write offset to next buffer.
    inline __device__ void move_to_next_write_buffer() {
        // if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
        //     this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
        // } else if( BUFFERS_PER_TILE > 1 ) {
        //     this->smem_write_buffer_ += BYTES_PER_BUFFER;
        // }
        if( BUFFERS_PER_TILE > 1 && smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
            this->smem_write_offset_ -= BYTES_PER_TILE_INC_BOUNDARY;
        } else if( BUFFERS_PER_TILE > 1 ) {
            this->smem_write_offset_ += BYTES_PER_BUFFER;
        }
    }

    // Move the write offset to next buffer. TODO: Remove that member function!
    inline __device__ void move_next_write_buffer() {
        this->move_to_next_write_buffer();
    }

    // Move the read offset.
    inline __device__ void move_read_offset(int delta) {
        this->smem_read_offset_ += delta;
    }

    // Move the write offset.
    inline __device__ void move_write_offset(int delta) {
        this->smem_write_offset_ += delta;
    }
    // Store to the tile in shared memory.
    template< int N >
    inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {
        uint32_t smem_ptrs[N];
        this->compute_store_pointers(smem_ptrs);
        // Trying to reduce the shared mem for Q from 4KB per buffer to 2KB per buffer.
        if( !PARTIAL_STORE || (tidx_ / THREADS_PER_ROW < ROWS) ) {
            sts(smem_ptrs, data);
        }
    }

    // Store to the tile in shared memory.
    template< int N, int M >
    inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) {
        uint32_t smem_ptrs[N];
        this->compute_store_pointers(smem_ptrs);
        sts(smem_ptrs, data, preds);
    }

    // Store to the tile in shared memory.
    template< int N >
    inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) {
        this->store(data, preds);
    }

    // Store to the tile in shared memory.
    template< int N >
    inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) {
        uint32_t tmp[1] = { preds };
        this->store(gmem_ptrs, tmp);
    }
    // The shared memory pointer.
    const uint32_t smem_;
    // The read offset. Reserve 4 offsets if needed.
    int smem_read_offset_;
    // The write offset.
    int smem_write_offset_;
    // The buffer base offset for read.
    // int smem_read_buffer_;
    // The buffer base offset for write.
    // int smem_write_buffer_;
    const int tidx_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
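// [Added illustration, not part of the original file.] The Smem_tile_without_skews ctor above
// picks each thread's STS address by XOR-swizzling the column index with a pattern derived from
// the row, so consecutive rows hit different shared-memory banks. A minimal host-side sketch of
// that mapping; all EXAMPLE_* values and the example_swizzled_write_offset() helper are
// hypothetical (a 128B row written with 16B stores, swizzled over 8 rows), not a real instantiation.
static constexpr int EXAMPLE_BYTES_PER_ROW = 128;
static constexpr int EXAMPLE_BYTES_PER_STS = 16;
static constexpr int EXAMPLE_THREADS_PER_ROW = EXAMPLE_BYTES_PER_ROW / EXAMPLE_BYTES_PER_STS;
static constexpr int EXAMPLE_ROWS_PER_XOR_PATTERN = 8;
static constexpr int EXAMPLE_COLS_PER_XOR_PATTERN = 1;

static constexpr int example_swizzled_write_offset(int tidx) {
    // Same arithmetic as the ctor above: row, XOR pattern, swizzled column, byte offset.
    int row = tidx / EXAMPLE_THREADS_PER_ROW;
    int xor_pattern = row % EXAMPLE_ROWS_PER_XOR_PATTERN * EXAMPLE_COLS_PER_XOR_PATTERN;
    int col = (tidx % EXAMPLE_THREADS_PER_ROW) ^ xor_pattern;
    return row * EXAMPLE_BYTES_PER_ROW + col * EXAMPLE_BYTES_PER_STS;
}

// Threads 0..7 write row 0 unswizzled; thread 8 (row 1) gets its column XOR'd with 1.
static_assert(example_swizzled_write_offset(0) == 0, "row 0, col 0");
static_assert(example_swizzled_write_offset(8) == 128 + 1 * 16, "row 1, col 0 ^ 1");
static_assert(example_swizzled_write_offset(9) == 128 + 0 * 16, "row 1, col 1 ^ 1");

////////////////////////////////////////////////////////////////////////////////////////////////////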
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The layout of the tile.
    typename Layout,
    // The size of the STS.
    int BYTES_PER_STS = 16,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE = 1,
    // Whether or not to use predicates.
    bool USE_PREDICATES = true
>
struct Smem_tile_a {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K, int MMAS_K_WITH_PADDING >
struct Compute_reset_mask {
    // The potential mask.
    enum { HALF = MMAS_K_WITH_PADDING / 2 };
    // The remainder.
    enum { MOD = MMAS_K % HALF };
    // The final value.
    enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask<MOD, HALF>::VALUE };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int MMAS_K_WITH_PADDING >
struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> {
    enum { VALUE = 0 };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int MMAS_K >
struct Compute_reset_mask<MMAS_K, MMAS_K> {
    enum { VALUE = MMAS_K - 1 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
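// [Added illustration, not part of the original file.] Compute_reset_mask folds the sequence of
// XOR increments applied across the K loop back to its starting point. For example, with 3 MMAs
// over a padded count of 4, the load() below XORs the read offset with 1, 3 and 1 (in units of
// BYTES_PER_LDS * 2), and 1 ^ 3 ^ 1 == 3 is exactly the mask that undoes it. Worked checks:
static_assert(Compute_reset_mask<3, 4>::VALUE == 3, "1 ^ 3 ^ 1");
static_assert(Compute_reset_mask<2, 4>::VALUE == 2, "1 ^ 3");
static_assert(Compute_reset_mask<4, 4>::VALUE == 3, "specialization: MMAS_K - 1");

////////////////////////////////////////////////////////////////////////////////////////////////////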
template< int N >
struct Rows_per_xor_pattern_a {
    // The size in bits.
    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };
    // The number of rows.
    enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
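// [Added illustration, not part of the original file.] Rows_per_xor_pattern_a above chooses how
// many rows share one XOR pattern from the row size in bits. A standalone sketch of the same
// formula; the helper name and the assumption of 16-bit elements (fp16/bf16) are mine,
// BITS_PER_ELEMENT_A itself is defined elsewhere in fmha.
static constexpr int example_rows_per_xor_pattern(int n_elements, int bits_per_element) {
    int n_in_bits = n_elements * bits_per_element;
    return n_in_bits <= 256 ? 2 : (n_in_bits <= 512 ? 4 : 8);
}
static_assert(example_rows_per_xor_pattern(16, 16) == 2, "K = 16 -> 256 bits");
static_assert(example_rows_per_xor_pattern(32, 16) == 4, "K = 32 -> 512 bits");
static_assert(example_rows_per_xor_pattern(64, 16) == 8, "K = 64 -> 1024 bits");

////////////////////////////////////////////////////////////////////////////////////////////////////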
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE,
    // How many rows to use for the XOR pattern to avoid bank conflicts?
    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a<Cta_tile::K>::VALUE
>
struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K,
                                                        fmha::BITS_PER_ELEMENT_A, BYTES_PER_STS,
                                                        BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1> {
    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The base class.
    using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::M, Cta_tile::K,
                                         fmha::BITS_PER_ELEMENT_A, BYTES_PER_STS,
                                         BUFFERS_PER_TILE, 0, ROWS_PER_XOR_PATTERN_, 1>;
    // The fragment.
    using Fragment = Fragment_a<Row>;

    // When we use padding to reach a power of two, special care has to be taken.
    using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;
    // The number of MMAs.
    using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;

    // The size of a single LDS in bytes.
    enum { BYTES_PER_LDS = 16 };
    // Ctor.
    inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) {

        // For documentation on the layout, see doc/mma_smem_layout.xlsx.

        // The number of warps.
        const int WARPS_M = Cta_tile::WARPS_M;
        const int WARPS_N = Cta_tile::WARPS_N;
        const int WARPS_K = Cta_tile::WARPS_K;
        static_assert(WARPS_M == 1);
        static_assert(WARPS_N == 4 || WARPS_N == 8);
        static_assert(WARPS_K == 1);
        static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4
                      || Base::ROWS_PER_XOR_PATTERN == 8);

        // The row and column read by the thread.
        int smem_read_row = (tidx & 0x0f);
        constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
        int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN)
                            * Base::COLS_PER_XOR_PATTERN;
        smem_read_col ^= (tidx & 0x10) / 16;

        // The shared memory offset.
        this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW_BEFORE_PACKING
                                  + smem_read_col * BYTES_PER_LDS;
    }
    // Rewind smem_read_offset for last LDS phase in main loop.
    inline __device__ void reverse_smem_read_offset(int ki = 0) {
        // Undo the pointer increment for the next ni.
        // Should match the load function below for ki = 0.
        if( Mma_tile_with_padding::MMAS_K >= 2 ) {
            this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
        }
    }
    // Load from shared memory.
    inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {
        #pragma unroll
        for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) {
            // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
            int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;

            // Load using LDSM.M88.4.
            uint4 tmp;
            // ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
            ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset);

            // Store the value into the fragment.
            a[mi].reg(0) = tmp.x;
            a[mi].reg(1) = tmp.y;
            a[mi].reg(2) = tmp.z;
            a[mi].reg(3) = tmp.w;
        }

        // Move the offset to the next position. See doc/mma_smem_layout.xlsx.
        static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
        if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
            this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
            this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
            this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
            this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
            this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
        }
    }
    // Reset the read offset.
    inline __device__ void reset_read_offset() {
        // The number of MMAs in the K dimension.
        enum { MMAS_K = Mma_tile::MMAS_K };
        // The number of MMAs in the K dimension when we include padding.
        enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
        // Assemble the mask.
        enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };

        // Reset the read offset.
        this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
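// [Added illustration, not part of the original file.] The if/else ladder at the end of
// Smem_tile_row_a::load() above advances through the K dimension by XOR-ing the read offset with
// 1, 3, 7, ... times (BYTES_PER_LDS * 2), depending on ki. A host-side sketch of that delta
// sequence; the helper name and the padded MMAS_K of 8 are example choices of mine.
static constexpr int example_k_step_xor_delta(int ki, int mmas_k_with_padding) {
    // Returns the multiplier of BYTES_PER_LDS * 2 applied after iteration ki.
    if( mmas_k_with_padding >= 32 && ki % 16 == 15 ) { return 31; }
    if( mmas_k_with_padding >= 16 && ki %  8 ==  7 ) { return 15; }
    if( mmas_k_with_padding >=  8 && ki %  4 ==  3 ) { return  7; }
    if( mmas_k_with_padding >=  4 && ki %  2 ==  1 ) { return  3; }
    if( mmas_k_with_padding >=  2 ) { return 1; }
    return 0;
}
// For 8 padded MMAs the sequence is 1, 3, 1, 7, 1, 3, 1, 7: every LDSM reads a different
// 32B-aligned slot, and the cumulative XOR returns to 0 after a full pass over K.
static_assert(example_k_step_xor_delta(0, 8) == 1 && example_k_step_xor_delta(1, 8) == 3
              && example_k_step_xor_delta(3, 8) == 7 && example_k_step_xor_delta(7, 8) == 7, "");

////////////////////////////////////////////////////////////////////////////////////////////////////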
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE
>
struct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {
    // The base class.
    using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Ctor.
    inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The layout of the tile.
    typename Layout,
    // The size of the STS.
    int BYTES_PER_STS = 16,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE = 1,
    // Whether or not to use predicates.
    bool USE_PREDICATES = true
>
struct Smem_tile_b {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
struct Rows_per_xor_pattern_b {
    // The size in bits.
    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };
    // The number of rows.
    enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The dimensions of the tile computed by the CTA.
typename
Cta_tile
,
// The size of the STS.
int
BYTES_PER_STS
,
// The number of buffers per tile.
int
BUFFERS_PER_TILE
,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int
ROWS_PER_XOR_PATTERN_
=
Rows_per_xor_pattern_col_b
<
Cta_tile
::
K
>
::
VALUE
>
struct
Smem_tile_col_b
:
public
Smem_tile_without_skews
<
Cta_tile
,
Cta_tile
::
N
,
Cta_tile
::
K
,
fmha
::
BITS_PER_ELEMENT_B
,
BYTES_PER_STS
,
BUFFERS_PER_TILE
,
0
,
ROWS_PER_XOR_PATTERN_
,
1
>
{
// The MMA tile.
using
Mma_tile
=
fmha
::
Hmma_tile
<
Cta_tile
>
;
// The base class.
using
Base
=
Smem_tile_without_skews
<
Cta_tile
,
Cta_tile
::
N
,
Cta_tile
::
K
,
fmha
::
BITS_PER_ELEMENT_B
,
BYTES_PER_STS
,
BUFFERS_PER_TILE
,
0
,
ROWS_PER_XOR_PATTERN_
,
1
>
;
// The fragment.
using
Fragment
=
Fragment_b
<
Col
>
;
// When we use padding to reach a power of two, special care has to be taken.
using
Cta_tile_with_padding
=
Cta_tile_with_k_with_padding
<
Cta_tile
>
;
// The number of MMAs.
using
Mma_tile_with_padding
=
fmha
::
Hmma_tile
<
Cta_tile_with_padding
>
;
// The size of a single LDS in bytes.
enum
{
BYTES_PER_LDS
=
16
};
// The number of STS per thread
enum
{
STS_PER_THREAD_
=
Base
::
ROWS
*
Base
::
THREADS_PER_ROW
/
Cta_tile
::
THREADS_PER_CTA
};
// The number of STS per thread must be at least 1.
enum
{
STS_PER_THREAD
=
Max
<
1
,
STS_PER_THREAD_
>::
VALUE
};
// Ctor.
inline
__device__
Smem_tile_col_b
(
void
*
smem
,
int
tidx
)
:
Base
(
smem
,
tidx
)
{
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const
int
WARPS_M
=
Cta_tile
::
WARPS_M
;
const
int
WARPS_N
=
Cta_tile
::
WARPS_N
;
const
int
WARPS_K
=
Cta_tile
::
WARPS_K
;
static_assert
(
Base
::
ROWS_PER_XOR_PATTERN
==
2
||
Base
::
ROWS_PER_XOR_PATTERN
==
4
||
Base
::
ROWS_PER_XOR_PATTERN
==
8
);
static_assert
(
WARPS_M
==
1
);
static_assert
(
WARPS_N
==
4
||
WARPS_N
==
8
);
static_assert
(
WARPS_K
==
1
);
// The masks to select the warps.
const
int
WARP_MASK_N
=
Warp_masks
<
WARPS_M
,
WARPS_N
,
WARPS_K
>::
N
;
// The divisor for the warps.
const
int
WARP_DIV_N
=
WARPS_M
*
1
*
Cta_tile
::
THREADS_PER_WARP
;
// The row and column read by the thread.
int
smem_read_row
=
(
tidx
&
WARP_MASK_N
)
/
WARP_DIV_N
*
Mma_tile
::
N_PER_MMA
+
(
tidx
&
0x07
)
+
(
tidx
&
0x10
)
/
2
;
constexpr
int
ROWS_PER_PACKING
=
Base
::
BYTES_PER_ROW
/
Base
::
BYTES_PER_ROW_BEFORE_PACKING
;
int
smem_read_col
=
((
smem_read_row
/
ROWS_PER_PACKING
)
%
Base
::
ROWS_PER_XOR_PATTERN
)
*
Base
::
COLS_PER_XOR_PATTERN
;
smem_read_col
^=
(
tidx
&
0x08
)
/
8
;
// The shared memory offset.
this
->
smem_read_offset_
=
smem_read_row
*
Base
::
BYTES_PER_ROW_BEFORE_PACKING
+
smem_read_col
*
BYTES_PER_LDS
;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline
__device__
void
reverse_smem_read_offset
(
int
ki
=
0
)
{
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if
(
Mma_tile_with_padding
::
MMAS_K
>=
2
)
{
this
->
smem_read_offset_
^=
BYTES_PER_LDS
*
2
;
}
}
// Load from shared memory.
inline
__device__
void
load
(
Fragment
(
&
b
)[
Mma_tile
::
MMAS_N
],
int
ki
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
Mma_tile
::
MMAS_N
;
++
ni
)
{
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int
offset
=
ni
*
Mma_tile
::
N_PER_MMA_PER_CTA
*
Base
::
BYTES_PER_ROW_BEFORE_PACKING
;
// Load using LDSM.M88.4.
uint4
tmp
;
// ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
ldsm
(
tmp
,
this
->
smem_
+
this
->
smem_read_offset_
+
offset
);
// Store the value into the fragment.
b
[
ni
].
reg
(
0
)
=
tmp
.
x
;
b
[
ni
].
reg
(
1
)
=
tmp
.
y
;
b
[
ni
].
reg
(
2
)
=
tmp
.
z
;
b
[
ni
].
reg
(
3
)
=
tmp
.
w
;
}
// Move the offset to the next possition. See doc/mma_smem_layout.xlsx.
static_assert
(
Mma_tile_with_padding
::
MMAS_K
<
64
,
"Not implemented"
);
if
(
Mma_tile_with_padding
::
MMAS_K
>=
32
&&
ki
%
16
==
15
)
{
this
->
smem_read_offset_
^=
31
*
BYTES_PER_LDS
*
2
;
}
else
if
(
Mma_tile_with_padding
::
MMAS_K
>=
16
&&
ki
%
8
==
7
)
{
this
->
smem_read_offset_
^=
15
*
BYTES_PER_LDS
*
2
;
}
else
if
(
Mma_tile_with_padding
::
MMAS_K
>=
8
&&
ki
%
4
==
3
)
{
this
->
smem_read_offset_
^=
7
*
BYTES_PER_LDS
*
2
;
}
else
if
(
Mma_tile_with_padding
::
MMAS_K
>=
4
&&
ki
%
2
==
1
)
{
this
->
smem_read_offset_
^=
3
*
BYTES_PER_LDS
*
2
;
}
else
if
(
Mma_tile_with_padding
::
MMAS_K
>=
2
)
{
this
->
smem_read_offset_
^=
1
*
BYTES_PER_LDS
*
2
;
}
}
// Reset the read offset.
inline
__device__
void
reset_read_offset
()
{
// The number of MMAs in the K dimension.
enum
{
MMAS_K
=
Mma_tile
::
MMAS_K
};
// The number of MMAs in the K dimension when we include padding.
enum
{
MMAS_K_WITH_PADDING
=
Mma_tile_with_padding
::
MMAS_K
};
// Assemble the mask.
enum
{
MASK
=
Compute_reset_mask
<
MMAS_K
,
MMAS_K_WITH_PADDING
>::
VALUE
};
// Reset the read offset.
this
->
smem_read_offset_
^=
MASK
*
BYTES_PER_LDS
*
2
;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE
>
struct Smem_tile_b<Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_col_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {
    // The base class.
    using Base = Smem_tile_col_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Ctor.
    inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The dimensions of the tile computed by the CTA.
typename
Cta_tile
,
// The size of the STS.
int
BYTES_PER_STS
,
// The number of buffers per tile.
int
BUFFERS_PER_TILE
,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int
ROWS_PER_XOR_PATTERN_
=
Rows_per_xor_pattern_row_b
<
Cta_tile
::
N
>
::
VALUE
,
// How many cols to use for the XOR pattern to avoid bank conflicts?
int
COLS_PER_XOR_PATTERN_
=
1
>
struct
Smem_tile_row_b
:
public
Smem_tile_without_skews
<
Cta_tile
,
Cta_tile
::
K
,
Cta_tile
::
N
,
fmha
::
BITS_PER_ELEMENT_B
,
BYTES_PER_STS
,
BUFFERS_PER_TILE
,
0
,
ROWS_PER_XOR_PATTERN_
,
COLS_PER_XOR_PATTERN_
>
{
// The MMA tile.
using
Mma_tile
=
fmha
::
Hmma_tile
<
Cta_tile
>
;
// The base class.
using
Base
=
Smem_tile_without_skews
<
Cta_tile
,
Cta_tile
::
K
,
Cta_tile
::
N
,
fmha
::
BITS_PER_ELEMENT_B
,
BYTES_PER_STS
,
BUFFERS_PER_TILE
,
0
,
ROWS_PER_XOR_PATTERN_
,
COLS_PER_XOR_PATTERN_
>
;
// The fragment.
using
Fragment
=
Fragment_b
<
Row
>
;
// Can we use LDSM? No if the data type is 32-bit large.
enum
{
USE_LDSMT
=
fmha
::
BITS_PER_ELEMENT_B
==
16
};
// The size of a single LDS in bytes.
enum
{
BYTES_PER_LDS
=
USE_LDSMT
?
16
:
4
};
// The number of elements per LDS.
enum
{
ELEMENTS_PER_LDS
=
BYTES_PER_LDS
*
8
/
fmha
::
BITS_PER_ELEMENT_B
};
// The number of STS per thread
enum
{
STS_PER_THREAD_
=
Base
::
ROWS
*
Base
::
THREADS_PER_ROW
/
Cta_tile
::
THREADS_PER_CTA
};
// The number of STS per thread must be at least 1.
enum
{
STS_PER_THREAD
=
Max
<
1
,
STS_PER_THREAD_
>::
VALUE
};
// Ctor.
inline
__device__
Smem_tile_row_b
(
void
*
smem
,
int
tidx
)
:
Base
(
smem
,
tidx
)
{
// The number of warps.
const
int
WARPS_M
=
Cta_tile
::
WARPS_M
;
const
int
WARPS_N
=
Cta_tile
::
WARPS_N
;
const
int
WARPS_K
=
Cta_tile
::
WARPS_K
;
static_assert
(
WARPS_K
==
1
);
static_assert
(
WARPS_M
==
4
||
WARPS_M
==
8
);
static_assert
(
WARPS_N
==
1
);
// The masks to select the warps.
const
int
WARP_MASK_N
=
Warp_masks
<
WARPS_M
,
WARPS_N
,
WARPS_K
>::
N
;
const
int
WARP_MASK_K
=
Warp_masks
<
WARPS_M
,
WARPS_N
,
WARPS_K
>::
K
;
// The divisor for the warps.
const
int
WARP_DIV_N
=
WARPS_M
*
1
*
Cta_tile
::
THREADS_PER_WARP
;
const
int
WARP_DIV_K
=
WARPS_M
*
WARPS_N
*
Cta_tile
::
THREADS_PER_WARP
;
static_assert
(
USE_LDSMT
);
static_assert
(
Base
::
ROWS_PER_XOR_PATTERN
==
2
||
Base
::
ROWS_PER_XOR_PATTERN
==
4
||
Base
::
ROWS_PER_XOR_PATTERN
==
8
);
// The row/col read by the thread.
int
smem_read_row
=
(
tidx
&
WARP_MASK_K
)
/
WARP_DIV_K
*
Mma_tile
::
MMAS_K
*
16
+
(
tidx
&
0x07
)
+
(
tidx
&
0x08
);
constexpr
int
ROWS_PER_PACKING
=
Base
::
BYTES_PER_ROW
/
Base
::
BYTES_PER_ROW_BEFORE_PACKING
;
int
smem_read_col
=
((
smem_read_row
/
ROWS_PER_PACKING
)
%
Base
::
ROWS_PER_XOR_PATTERN
)
*
Base
::
COLS_PER_XOR_PATTERN
;
smem_read_col
^=
(
tidx
&
WARP_MASK_N
)
/
WARP_DIV_N
*
2
+
(
tidx
&
0x10
)
/
16
;
// The shared memory offset.
this
->
smem_read_offset_
=
smem_read_row
*
Base
::
BYTES_PER_ROW_BEFORE_PACKING
+
smem_read_col
*
BYTES_PER_LDS
;
// Fill zeroes for group conv
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline
__device__
void
reverse_smem_read_offset
(
int
ki
=
0
)
{
// The size of each element in bits.
const
int
BITS_PER_ELT
=
fmha
::
BITS_PER_ELEMENT_B
;
// The size in bytes of the data needed to compute an MMA per CTA.
const
int
BYTES_PER_MMA_PER_CTA
=
Mma_tile
::
N_PER_MMA_PER_CTA
*
BITS_PER_ELT
/
8
;
#pragma unroll
for
(
int
ni
=
0
;
ni
<
Mma_tile
::
MMAS_N
;
++
ni
)
{
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if
(
BYTES_PER_MMA_PER_CTA
>=
128
)
{
// Nothing to do!
}
else
if
(
BYTES_PER_MMA_PER_CTA
==
64
&&
Mma_tile
::
MMAS_N
>
1
)
{
this
->
smem_read_offset_
^=
BYTES_PER_MMA_PER_CTA
;
}
else
if
(
BYTES_PER_MMA_PER_CTA
==
64
)
{
// Nothing to do!
}
else
if
(
BYTES_PER_MMA_PER_CTA
==
32
&&
Mma_tile
::
MMAS_N
==
4
)
{
this
->
smem_read_offset_
^=
BYTES_PER_LDS
*
(
ni
%
2
==
0
?
2
:
6
);
}
else
if
(
BYTES_PER_MMA_PER_CTA
==
32
&&
Mma_tile
::
MMAS_N
==
2
)
{
this
->
smem_read_offset_
^=
BYTES_PER_LDS
*
2
;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if
(
BYTES_PER_MMA_PER_CTA
==
64
&&
Mma_tile
::
MMAS_N
>
1
&&
Mma_tile
::
MMAS_N
%
2
==
1
)
{
this
->
smem_read_offset_
^=
BYTES_PER_MMA_PER_CTA
;
}
}
// Load from shared memory.
inline
__device__
void
load
(
Fragment
(
&
b
)[
Mma_tile
::
MMAS_N
],
int
ki
)
{
// The size of each element in bits.
const
int
BITS_PER_ELT
=
fmha
::
BITS_PER_ELEMENT_B
;
// The size in bytes of the data needed to compute an MMA per CTA.
const
int
BYTES_PER_MMA_PER_CTA
=
Mma_tile
::
N_PER_MMA_PER_CTA
*
BITS_PER_ELT
/
8
;
// uint32_t smem_read_og = this->smem_ + this->smem_read_offset_;
#pragma unroll
for
(
int
ni
=
0
;
ni
<
Mma_tile
::
MMAS_N
;
++
ni
)
{
// Prepare the offset.
int
offset
=
ki
*
Base
::
ROWS_PER_XOR_PATTERN
*
2
*
Base
::
BYTES_PER_ROW_BEFORE_PACKING
;
if
(
BYTES_PER_MMA_PER_CTA
==
32
)
{
offset
+=
this
->
smem_read_offset_
;
}
else
if
(
BYTES_PER_MMA_PER_CTA
==
64
)
{
offset
+=
this
->
smem_read_offset_
+
(
ni
/
2
)
*
BYTES_PER_MMA_PER_CTA
*
2
;
}
else
{
offset
+=
this
->
smem_read_offset_
+
(
ni
)
*
BYTES_PER_MMA_PER_CTA
;
}
// Load the data using LDSM.MT88.2.
// uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;
uint32_t
ptr
=
this
->
smem_
+
offset
;
uint4
tmp
;
if
(
USE_LDSMT
)
{
ldsmt
(
tmp
,
ptr
);
}
else
{
lds
(
tmp
.
x
,
(
ptr
)
+
0
*
Base
::
BYTES_PER_ROW_BEFORE_PACKING
);
lds
(
tmp
.
y
,
(
ptr
)
+
4
*
Base
::
BYTES_PER_ROW_BEFORE_PACKING
);
lds
(
tmp
.
z
,
(
ptr
^
32
)
+
0
*
Base
::
BYTES_PER_ROW_BEFORE_PACKING
);
lds
(
tmp
.
w
,
(
ptr
^
32
)
+
4
*
Base
::
BYTES_PER_ROW_BEFORE_PACKING
);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("BYTES_PER_MMA_PER_CTA=%d, ni = %d, smem_read diff = %d\n", BYTES_PER_MMA_PER_CTA, ni, ptr - smem_read_og);
// }
// Store those values in the fragment.
b
[
ni
].
reg
(
0
)
=
tmp
.
x
;
b
[
ni
].
reg
(
1
)
=
tmp
.
y
;
b
[
ni
].
reg
(
2
)
=
tmp
.
z
;
b
[
ni
].
reg
(
3
)
=
tmp
.
w
;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if
(
BYTES_PER_MMA_PER_CTA
>=
128
)
{
// Nothing to do!
}
else
if
(
BYTES_PER_MMA_PER_CTA
==
64
&&
Mma_tile
::
MMAS_N
>
1
)
{
this
->
smem_read_offset_
^=
BYTES_PER_MMA_PER_CTA
;
}
else
if
(
BYTES_PER_MMA_PER_CTA
==
64
)
{
// Nothing to do!
}
else
if
(
BYTES_PER_MMA_PER_CTA
==
32
&&
Mma_tile
::
MMAS_N
==
8
)
{
this
->
smem_read_offset_
^=
BYTES_PER_LDS
*
(
ni
%
4
==
3
?
14
:
(
ni
%
2
==
1
?
6
:
2
));
}
else
if
(
BYTES_PER_MMA_PER_CTA
==
32
&&
Mma_tile
::
MMAS_N
==
4
)
{
this
->
smem_read_offset_
^=
BYTES_PER_LDS
*
(
ni
%
2
==
0
?
2
:
6
);
}
else
if
(
BYTES_PER_MMA_PER_CTA
==
32
&&
Mma_tile
::
MMAS_N
==
2
)
{
this
->
smem_read_offset_
^=
BYTES_PER_LDS
*
2
;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if
(
BYTES_PER_MMA_PER_CTA
==
64
&&
Mma_tile
::
MMAS_N
>
1
&&
Mma_tile
::
MMAS_N
%
2
==
1
)
{
this
->
smem_read_offset_
^=
BYTES_PER_MMA_PER_CTA
;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE
>
struct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {
    // The base class.
    using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Ctor.
    inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Cta_tile
>
struct
Smem_tile_v
:
public
fmha
::
Smem_tile_without_skews
<
Cta_tile
,
Cta_tile
::
K
,
Cta_tile
::
N
,
16
,
16
,
1
,
0
,
Rows_per_xor_pattern_col_b
<
Cta_tile
::
N
>::
VALUE
,
1
>
{
// The base class.
using
Base
=
Smem_tile_without_skews
<
Cta_tile
,
Cta_tile
::
K
,
Cta_tile
::
N
,
16
,
16
,
1
,
0
,
Rows_per_xor_pattern_col_b
<
Cta_tile
::
N
>::
VALUE
,
1
>
;
// The MMA tile.
using
Mma_tile
=
fmha
::
Hmma_tile
<
Cta_tile
>
;
// The fragment.
using
Fragment
=
Fragment_b
<
fmha
::
Col
>
;
// The size of a single LDS in bytes.
enum
{
BYTES_PER_LDS
=
16
};
// Ctor.
inline
__device__
Smem_tile_v
(
void
*
smem
,
int
tidx
)
:
Base
(
smem
,
tidx
)
{
// The row/col read by the thread.
int
read_row
,
read_col
;
static_assert
(
Cta_tile
::
WARPS_M
==
1
&&
Cta_tile
::
WARPS_N
==
1
&&
(
Cta_tile
::
WARPS_K
==
4
||
Cta_tile
::
WARPS_K
==
8
));
read_row
=
(
tidx
&
0xe0
)
/
2
+
(
tidx
&
0x0f
);
constexpr
int
ROWS_PER_PACKING
=
Base
::
BYTES_PER_ROW
/
Base
::
BYTES_PER_ROW_BEFORE_PACKING
;
read_col
=
((
read_row
/
ROWS_PER_PACKING
)
%
Base
::
ROWS_PER_XOR_PATTERN
)
*
Base
::
COLS_PER_XOR_PATTERN
;
read_col
^=
(
tidx
&
0x10
)
/
16
;
// The shared memory offset.
this
->
smem_read_offset_
=
read_row
*
Base
::
BYTES_PER_ROW_BEFORE_PACKING
+
read_col
*
BYTES_PER_LDS
;
}
// Load from shared memory.
inline
__device__
void
load
(
Fragment
(
&
b
)[
Mma_tile
::
MMAS_N
],
int
ki
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
Mma_tile
::
MMAS_N
;
++
ni
)
{
// Jump by 16 * #warps row.
int
row
=
ki
*
16
*
Cta_tile
::
WARPS_K
;
// Load the data using LDSM.MT88.2.
uint4
tmp
;
fmha
::
ldsmt
(
tmp
,
this
->
smem_
+
this
->
smem_read_offset_
+
row
*
Base
::
BYTES_PER_ROW_BEFORE_PACKING
);
b
[
ni
].
reg
(
0
)
=
tmp
.
x
;
b
[
ni
].
reg
(
1
)
=
tmp
.
y
;
b
[
ni
].
reg
(
2
)
=
tmp
.
z
;
b
[
ni
].
reg
(
3
)
=
tmp
.
w
;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if
(
Mma_tile
::
MMAS_N
==
1
)
{
// noop
}
else
if
(
Mma_tile
::
MMAS_N
==
2
)
{
this
->
smem_read_offset_
^=
BYTES_PER_LDS
*
2
;
}
else
if
(
Mma_tile
::
MMAS_N
==
4
)
{
this
->
smem_read_offset_
^=
BYTES_PER_LDS
*
(
ni
%
2
==
0
?
2
:
6
);
}
else
if
(
Mma_tile
::
MMAS_N
==
8
)
{
this
->
smem_read_offset_
^=
BYTES_PER_LDS
*
(
ni
%
4
==
3
?
14
:
(
ni
%
2
==
1
?
6
:
2
));
}
else
{
assert
(
false
);
// Not implemented!
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Cta_tile
>
struct
Smem_tile_o
{
// The MMA tile.
using
Mma_tile
=
fmha
::
Hmma_tile
<
Cta_tile
>
;
// The accumulators.
using
Accumulator
=
fmha
::
Fragment_accumulator
;
// The accumulators.
using
Data_type
=
typename
Accumulator
::
Data_type
;
// The size of each element.
static
constexpr
int
BYTES_PER_ELEMENT
=
sizeof
(
Data_type
);
// The size of each STS.
static
constexpr
int
BYTES_PER_STS
=
8
;
// The size of each row in shared memory.
static
constexpr
int
BYTES_PER_ROW
=
Cta_tile
::
N
*
Cta_tile
::
WARPS_K
*
BYTES_PER_ELEMENT
;
// The size of each LDS.
static
constexpr
int
BYTES_PER_LDS
=
16
;
static
constexpr
int
THREADS_PER_ROW
=
Cta_tile
::
N
*
BYTES_PER_ELEMENT
/
BYTES_PER_LDS
;
// The number of rows.
static
constexpr
int
ROWS
=
Cta_tile
::
M
;
// The number of "rows" to process per loop iteration (in the "epilogue").
static
constexpr
int
ROWS_PER_LOOP
=
ROWS
<=
64
?
ROWS
:
(
int
)
Mma_tile
::
M_PER_MMA_PER_CTA
;
// The number of outer loops.
static
constexpr
int
LOOPS
=
ROWS
/
ROWS_PER_LOOP
;
// Make sure it matches our expectations.
static_assert
(
LOOPS
==
1
||
LOOPS
==
(
int
)
Mma_tile
::
MMAS_M
,
""
);
// The number of rows loaded per LDS.
static
constexpr
int
ROWS_PER_LDS
=
Cta_tile
::
THREADS_PER_CTA
/
THREADS_PER_ROW
;
// Do we have to guard against partial writes/reads.
static
constexpr
bool
HAS_INCOMPLETE_LDS
=
ROWS_PER_LOOP
%
ROWS_PER_LDS
!=
0
;
// The total number of LDS per loop.
static
constexpr
int
LDS_PER_LOOP
=
fmha
::
DivUpConstexpr
(
ROWS_PER_LOOP
,
ROWS_PER_LDS
);
// The amount of shared memory.
static
constexpr
int
BYTES_PER_TILE
=
ROWS_PER_LOOP
*
BYTES_PER_ROW
;
// The write pointer.
uint32_t
smem_write_
,
smem_read_
;
// Is the thread active for the last LDS of the series?
int
is_active_for_last_lds_
;
// static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K);
static_assert
(
LOOPS
==
1
||
LOOPS
==
(
int
)
Mma_tile
::
MMAS_M
,
""
);
// Ctor.
inline
__device__
Smem_tile_o
(
void
*
smem
,
int
tidx
)
{
// Get a 32-bit value for the shared memory address.
uint32_t
smem_
=
__nvvm_get_smem_pointer
(
smem
);
static_assert
(
Cta_tile
::
WARPS_M
==
1
&&
Cta_tile
::
WARPS_N
==
1
&&
(
Cta_tile
::
WARPS_K
==
4
||
Cta_tile
::
WARPS_K
==
8
));
static_assert
(
Cta_tile
::
N
==
16
||
Cta_tile
::
N
==
32
||
Cta_tile
::
N
==
64
||
Cta_tile
::
N
==
128
);
int
write_row
=
(
tidx
&
0x1c
)
/
4
;
const
int
lane
=
tidx
%
32
;
const
int
warp
=
tidx
/
32
;
constexpr
int
ELEMENTS_PER_STS
=
BYTES_PER_STS
/
BYTES_PER_ELEMENT
;
constexpr
int
STS_PER_WARP
=
16
*
Mma_tile
::
MMAS_N
/
ELEMENTS_PER_STS
;
int
write_col
=
warp
*
STS_PER_WARP
+
lane
%
STS_PER_WARP
;
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("write_row = %d, write_col = %d\n", write_row, write_col);
// }
// if ((blockIdx.x == 0) && (blockIdx.y == 0) && (write_row == 0) && (write_col == 0)) {
// printf("threadIdx.x = %d\n", threadIdx.x);
// }
// Assemble the write pointer.
smem_write_
=
smem_
+
write_row
*
BYTES_PER_ROW
+
write_col
*
BYTES_PER_STS
;
// The element read by each thread.
int
read_row
=
tidx
/
THREADS_PER_ROW
;
int
read_col
=
tidx
%
THREADS_PER_ROW
;
// Take the XOR pattern into account for the column.
read_col
^=
2
*
(
read_row
%
(
Cta_tile
::
N
==
16
?
2
:
(
Cta_tile
::
N
==
32
?
4
:
8
)));
// read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8))));
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("read_row = %d, read_col = %d\n", read_row, read_col);
// }
// if ((blockIdx.x == 0) && (blockIdx.y == 0) && (read_row == 0) && (read_col == 0)) {
// printf("threadIdx.x = %d\n", threadIdx.x);
// }
// Assemble the read pointer.
this
->
smem_read_
=
smem_
+
read_row
*
BYTES_PER_ROW
+
read_col
*
BYTES_PER_LDS
;
// Is that thread active on the last LDS?
if
(
HAS_INCOMPLETE_LDS
)
{
this
->
is_active_for_last_lds_
=
read_row
+
(
LDS_PER_LOOP
-
1
)
*
ROWS_PER_LDS
<
Cta_tile
::
M
;
}
}
// Load the output fragments.
template
<
bool
zero_init
=
true
>
inline
__device__
void
load
(
uint4
(
&
out
)[
LDS_PER_LOOP
])
const
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
LDS_PER_LOOP
;
++
ii
)
{
// Load the elements before the reduction (split-K).
uint4
tmp
[
Cta_tile
::
WARPS_K
];
#pragma unroll
for
(
int
jj
=
0
;
jj
<
Cta_tile
::
WARPS_K
;
++
jj
)
{
int
imm
=
ii
*
ROWS_PER_LDS
*
BYTES_PER_ROW
+
jj
*
Cta_tile
::
N
*
BYTES_PER_ELEMENT
;
uint32_t
smem_read
=
this
->
smem_read_
+
imm
;
// TD [2022-06-05] Ugly fix for d=128 in the forward pass, maybe there's a better way.
if
((
Cta_tile
::
N
==
128
)
&&
(
ROWS_PER_LDS
==
4
)
&&
(
ii
%
2
==
1
))
{
smem_read
^=
8
*
BYTES_PER_LDS
;
}
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("imm diff = %d\n", smem_read - this->smem_read_);
// }
if
(
!
HAS_INCOMPLETE_LDS
||
(
ii
<
LDS_PER_LOOP
-
1
||
this
->
is_active_for_last_lds_
)
)
{
// fmha::lds(tmp[jj], this->smem_read_ + imm);
fmha
::
lds
(
tmp
[
jj
],
smem_read
);
}
}
// Perform the reduction.
out
[
ii
]
=
zero_init
?
tmp
[
0
]
:
fmha
::
fadd4
(
out
[
ii
],
tmp
[
0
]);
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("out reduction: out = %.6f\n", reinterpret_cast<float (&)[4]>(out[ii])[0]);
// }
#pragma unroll
for
(
int
jj
=
1
;
jj
<
Cta_tile
::
WARPS_K
;
++
jj
)
{
out
[
ii
]
=
fmha
::
fadd4
(
out
[
ii
],
tmp
[
jj
]);
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("out reduction tmp = %.6f, out = %.6f\n", reinterpret_cast<float (&)[4]>(tmp[jj])[0], reinterpret_cast<float (&)[4]>(out[ii])[0]);
// }
}
}
}
// Store the accumulators.
template
<
int
M
,
int
N
>
inline
__device__
void
store
(
const
Accumulator
(
&
acc
)[
M
][
N
],
int
mi
)
{
// uint32_t smem_write_og = this->smem_write_;
static
constexpr
int
M_PER_MMA
=
Mma_tile
::
M_PER_MMA_PER_CTA
;
#pragma unroll
for
(
int
ni
=
0
;
ni
<
Mma_tile
::
MMAS_N
;
++
ni
)
{
// The number of MMAs that are stored per loop iteration.
static
constexpr
int
MMAS_M_PER_LOOP
=
Mma_tile
::
MMAS_M
/
LOOPS
;
// Store 1st column of the different MMAs.
#pragma unroll
for
(
int
mj
=
0
;
mj
<
MMAS_M_PER_LOOP
;
++
mj
)
{
// Precompute the immediates to jump between rows.
int
row_0
=
(
mj
*
M_PER_MMA
+
0
)
*
BYTES_PER_ROW
;
int
row_1
=
(
mj
*
M_PER_MMA
+
8
)
*
BYTES_PER_ROW
;
uint2
tmp0
,
tmp1
;
tmp0
.
x
=
acc
[
mi
*
MMAS_M_PER_LOOP
+
mj
][
ni
].
reg
(
0
);
tmp0
.
y
=
acc
[
mi
*
MMAS_M_PER_LOOP
+
mj
][
ni
].
reg
(
1
);
tmp1
.
x
=
acc
[
mi
*
MMAS_M_PER_LOOP
+
mj
][
ni
].
reg
(
2
);
tmp1
.
y
=
acc
[
mi
*
MMAS_M_PER_LOOP
+
mj
][
ni
].
reg
(
3
);
// Store.
fmha
::
sts
(
this
->
smem_write_
+
row_0
,
tmp0
);
fmha
::
sts
(
this
->
smem_write_
+
row_1
,
tmp1
);
}
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
// }
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// uint4 read_tmp;
// fmha::lds(read_tmp, this->smem_read_);
// printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
// }
// Swizzle the write pointer using a XOR of 16B.
this
->
smem_write_
^=
32
;
// Store 2nd column of the different MMAs.
#pragma unroll
for
(
int
mj
=
0
;
mj
<
MMAS_M_PER_LOOP
;
++
mj
)
{
// Precompute the immediates to jump between rows.
int
row_0
=
(
mj
*
M_PER_MMA
+
0
)
*
BYTES_PER_ROW
;
int
row_1
=
(
mj
*
M_PER_MMA
+
8
)
*
BYTES_PER_ROW
;
uint2
tmp0
,
tmp1
;
tmp0
.
x
=
acc
[
mi
*
MMAS_M_PER_LOOP
+
mj
][
ni
].
reg
(
4
);
tmp0
.
y
=
acc
[
mi
*
MMAS_M_PER_LOOP
+
mj
][
ni
].
reg
(
5
);
tmp1
.
x
=
acc
[
mi
*
MMAS_M_PER_LOOP
+
mj
][
ni
].
reg
(
6
);
tmp1
.
y
=
acc
[
mi
*
MMAS_M_PER_LOOP
+
mj
][
ni
].
reg
(
7
);
// Store.
fmha
::
sts
(
this
->
smem_write_
+
row_0
,
tmp0
);
fmha
::
sts
(
this
->
smem_write_
+
row_1
,
tmp1
);
}
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
// }
// Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
static_assert
(
Mma_tile
::
MMAS_N
<=
8
,
"Not implemented"
);
if
(
Mma_tile
::
MMAS_N
>=
8
&&
ni
%
4
==
3
)
{
this
->
smem_write_
^=
15
*
32
;
}
else
if
(
Mma_tile
::
MMAS_N
>=
4
&&
ni
%
2
==
1
)
{
this
->
smem_write_
^=
7
*
32
;
}
else
if
(
Mma_tile
::
MMAS_N
>=
2
)
{
this
->
smem_write_
^=
3
*
32
;
}
else
{
this
->
smem_write_
^=
3
*
32
;
}
// this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// uint4 read_tmp;
// fmha::lds(read_tmp, this->smem_read_);
// printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
// }
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
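// [Added illustration, not part of the original file.] Smem_tile_o::load() above performs the
// split-K reduction: each of the WARPS_K warps wrote its partial O tile to a different column
// slice of the row, and the loader sums the WARPS_K partial vectors with fadd4. A plain-CUDA
// sketch of that reduction over hypothetical partials (fadd4 itself is defined elsewhere in fmha):
__device__ inline float4 example_reduce_splitk(const float4 *partials, int warps_k) {
    float4 acc = partials[0];
    for( int jj = 1; jj < warps_k; ++jj ) {
        acc.x += partials[jj].x;
        acc.y += partials[jj].y;
        acc.z += partials[jj].z;
        acc.w += partials[jj].w;
    }
    return acc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////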
template
<
typename
Cta_tile
>
struct
Smem_tile_mma
{
using
Mma_tile
=
fmha
::
Hmma_tile
<
Cta_tile
>
;
using
Fragment
=
fmha
::
Fragment_a
<
fmha
::
Col
>
;
enum
{
COLS
=
Cta_tile
::
N
};
enum
{
BYTES_PER_ELT
=
2
};
enum
{
BYTES_PER_STS
=
4
};
enum
{
BYTES_PER_ROW
=
COLS
*
BYTES_PER_ELT
};
// TODO
enum
{
BYTES_PER_TILE
=
Cta_tile
::
M
*
BYTES_PER_ROW
};
enum
{
WARPS_M
=
Cta_tile
::
WARPS_M
};
enum
{
WARPS_N
=
Cta_tile
::
WARPS_N
};
enum
{
WARPS_K
=
Cta_tile
::
WARPS_K
};
static_assert
(
WARPS_K
==
1
);
inline
__device__
Smem_tile_mma
(
char
*
smem
,
int
tidx
)
{
uint32_t
smem_
=
__nvvm_get_smem_pointer
(
smem
);
int
write_col
,
write_row
;
static_assert
(
WARPS_M
==
1
&&
(
WARPS_N
==
4
||
WARPS_N
==
8
)
||
(
WARPS_M
==
4
||
WARPS_M
==
8
)
||
WARPS_N
==
1
);
if
(
WARPS_M
==
1
&&
(
WARPS_N
==
4
||
WARPS_N
==
8
)
)
{
write_row
=
(
tidx
&
0x1c
)
/
4
;
write_col
=
(
tidx
&
0xe0
)
/
4
+
(
tidx
&
0x03
);
write_col
^=
(
write_row
&
0x07
)
*
4
;
}
else
{
write_row
=
(
tidx
&
0xe0
)
/
2
+
(
tidx
&
0x1c
)
/
4
;
write_col
=
(
tidx
&
0x03
);
// write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x0f)))) * 4;
write_col
^=
(
write_row
&
(
BYTES_PER_ROW
==
32
?
0x01
:
(
BYTES_PER_ROW
==
64
?
0x03
:
(
BYTES_PER_ROW
==
128
?
0x07
:
0x07
))))
*
4
;
}
// write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
smem_write_
=
smem_
+
write_row
*
BYTES_PER_ROW
+
write_col
*
BYTES_PER_STS
;
}
template
<
int
M
,
int
N
>
inline
__device__
void
store
(
const
uint4
(
&
regs
)[
M
][
N
])
{
static_assert
(
COLS
==
Cta_tile
::
N
);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
// size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
// fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
// fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
// offset ^= 4 * BYTES_PER_STS;
// fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
// fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
// size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t
offset
=
smem_write_
+
mi
*
WARPS_M
*
16
*
BYTES_PER_ROW
+
ni
*
WARPS_N
*
16
*
BYTES_PER_ELT
;
fmha
::
sts
(
offset
+
0
*
BYTES_PER_ROW
,
regs
[
mi
][
ni
].
x
);
fmha
::
sts
(
offset
+
8
*
BYTES_PER_ROW
,
regs
[
mi
][
ni
].
z
);
offset
^=
4
*
BYTES_PER_STS
;
fmha
::
sts
(
offset
+
0
*
BYTES_PER_ROW
,
regs
[
mi
][
ni
].
y
);
fmha
::
sts
(
offset
+
8
*
BYTES_PER_ROW
,
regs
[
mi
][
ni
].
w
);
}
}
}
template
<
typename
Fragment
,
int
M
,
int
N
>
inline
__device__
void
store
(
const
Fragment
(
&
frag
)[
N
][
M
])
{
static_assert
(
COLS
==
Cta_tile
::
N
);
uint4
regs
[
M
][
N
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
// Need to transpose reg(1) and reg(2) here since when we load it we transpose again.
regs
[
mi
][
ni
]
=
make_uint4
(
frag
[
ni
][
mi
].
reg
(
0
),
frag
[
ni
][
mi
].
reg
(
2
),
frag
[
ni
][
mi
].
reg
(
1
),
frag
[
ni
][
mi
].
reg
(
3
));
}
}
this
->
store
(
regs
);
}
// uint32_t smem_;
// uint32_t write_offset_;
uint32_t
smem_write_
;
};
template
<
typename
Cta_tile
,
typename
Base
=
Smem_tile_mma
<
Cta_tile
>
>
struct
Smem_tile_mma_transposed
:
public
Base
{
enum
{
BYTES_PER_LDS
=
16
};
enum
{
BYTES_PER_ROW
=
Base
::
BYTES_PER_ROW
};
enum
{
BYTES_PER_ELT
=
Base
::
BYTES_PER_ELT
};
enum
{
WARPS_M
=
Base
::
WARPS_M
};
enum
{
WARPS_N
=
Base
::
WARPS_N
};
static_assert
(
WARPS_M
==
1
&&
(
WARPS_N
==
4
||
WARPS_N
==
8
));
using
Fragment
=
typename
Base
::
Fragment
;
inline
__device__
Smem_tile_mma_transposed
(
char
*
smem
,
int
tidx
)
:
Base
(
smem
,
tidx
)
{
uint32_t
smem_
=
__nvvm_get_smem_pointer
(
smem
);
static_assert
(
WARPS_M
==
1
&&
(
WARPS_N
==
4
||
WARPS_N
==
8
));
int
read_row
,
read_col
;
read_row
=
(
tidx
&
0x0f
);
read_col
=
(
tidx
&
0xe0
)
/
16
+
(
tidx
&
0x1c
)
/
16
;
// read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x0f))));
read_col
^=
(
read_row
&
0x07
);
// read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
smem_read_
=
smem_
+
read_row
*
BYTES_PER_ROW
+
read_col
*
BYTES_PER_LDS
;
}
template
<
int
M
,
int
N
>
inline
__device__
void
load
(
Fragment
(
&
frag
)[
M
][
N
])
{
static_assert
(
Base
::
COLS
==
Cta_tile
::
N
);
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
// size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4
dst
;
// fmha::ldsmt(dst, this->smem_ + offset);
// size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t
offset
=
smem_read_
+
mi
*
WARPS_M
*
16
*
BYTES_PER_ROW
+
ni
*
WARPS_N
*
16
*
BYTES_PER_ELT
;
fmha
::
ldsmt
(
dst
,
offset
);
frag
[
mi
][
ni
].
reg
(
0
)
=
dst
.
x
;
frag
[
mi
][
ni
].
reg
(
1
)
=
dst
.
z
;
// Fragment A regs col major!
frag
[
mi
][
ni
].
reg
(
2
)
=
dst
.
y
;
frag
[
mi
][
ni
].
reg
(
3
)
=
dst
.
w
;
}
}
}
// uint32_t read_offset_;
uint32_t
smem_read_
;
};
template
<
typename
Cta_tile
,
typename
Base
=
Smem_tile_mma
<
Cta_tile
>
>
struct
Smem_tile_mma_epilogue
:
public
Base
{
enum
{
BYTES_PER_LDS
=
16
};
enum
{
BYTES_PER_ROW
=
Base
::
BYTES_PER_ROW
};
enum
{
BYTES_PER_ELT
=
Base
::
BYTES_PER_ELT
};
enum
{
THREADS_PER_ROW
=
BYTES_PER_ROW
/
BYTES_PER_LDS
};
static_assert
(
THREADS_PER_ROW
*
BYTES_PER_LDS
==
BYTES_PER_ROW
);
enum
{
ROWS_PER_LDS
=
Cta_tile
::
THREADS_PER_CTA
/
THREADS_PER_ROW
};
enum
{
NUM_LDS
=
Cta_tile
::
M
/
ROWS_PER_LDS
};
static_assert
(
NUM_LDS
*
ROWS_PER_LDS
==
Cta_tile
::
M
);
enum
{
WARPS_M
=
Base
::
WARPS_M
};
enum
{
WARPS_N
=
Base
::
WARPS_N
};
static_assert
((
WARPS_M
==
4
||
WARPS_N
==
8
)
||
WARPS_N
==
1
);
using
Acc
=
fmha
::
Fragment_accumulator
;
inline
__device__
Smem_tile_mma_epilogue
(
char
*
smem
,
int
tidx
)
:
Base
(
smem
,
tidx
)
{
uint32_t
smem_
=
__nvvm_get_smem_pointer
(
smem
);
const
int
read_row
=
tidx
/
THREADS_PER_ROW
;
int
read_col
=
tidx
%
THREADS_PER_ROW
;
// read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
static_assert
(
Base
::
BYTES_PER_ROW
==
32
||
Base
::
BYTES_PER_ROW
==
64
||
Base
::
BYTES_PER_ROW
==
128
||
Base
::
BYTES_PER_ROW
==
256
);
read_col
^=
(
read_row
&
(
Base
::
BYTES_PER_ROW
==
32
?
0x01
:
(
Base
::
BYTES_PER_ROW
==
64
?
0x03
:
(
Base
::
BYTES_PER_ROW
==
128
?
0x07
:
0x07
))));
// read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
smem_read_
=
smem_
+
read_row
*
BYTES_PER_ROW
+
read_col
*
BYTES_PER_LDS
;
}
inline
__device__
void
load
(
uint4
(
&
data
)[
NUM_LDS
])
{
for
(
int
ii
=
0
;
ii
<
NUM_LDS
;
ii
++
)
{
// size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
// fmha::lds(data[ii], this->smem_ + offset);
// size_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
uint32_t
offset
=
smem_read_
+
ii
*
ROWS_PER_LDS
*
BYTES_PER_ROW
;
fmha
::
lds
(
data
[
ii
],
offset
);
}
}
template
<
typename
elem_type
=
__half
,
int
M
,
int
N
>
inline
__device__
void
store
(
const
Acc
(
&
acc
)[
M
][
N
]){
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
// 1st row - 4 elements per row.
float
tmp00
=
acc
[
mi
][
ni
].
elt
(
0
);
float
tmp01
=
acc
[
mi
][
ni
].
elt
(
1
);
float
tmp02
=
acc
[
mi
][
ni
].
elt
(
4
);
float
tmp03
=
acc
[
mi
][
ni
].
elt
(
5
);
// 2nd row - 4 elements per row.
float
tmp10
=
acc
[
mi
][
ni
].
elt
(
2
);
float
tmp11
=
acc
[
mi
][
ni
].
elt
(
3
);
float
tmp12
=
acc
[
mi
][
ni
].
elt
(
6
);
float
tmp13
=
acc
[
mi
][
ni
].
elt
(
7
);
uint32_t
x
=
fmha
::
float2_pack
<
elem_type
>
(
tmp00
,
tmp01
);
uint32_t
y
=
fmha
::
float2_pack
<
elem_type
>
(
tmp02
,
tmp03
);
uint32_t
z
=
fmha
::
float2_pack
<
elem_type
>
(
tmp10
,
tmp11
);
uint32_t
w
=
fmha
::
float2_pack
<
elem_type
>
(
tmp12
,
tmp13
);
// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);
// fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);
// offset ^= 4 * Base::BYTES_PER_STS;
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
// fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
// size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t
offset
=
(
this
->
smem_write_
^
(
ni
*
32
))
+
mi
*
WARPS_M
*
16
*
BYTES_PER_ROW
;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_);
// }
fmha
::
sts
(
offset
+
0
*
BYTES_PER_ROW
,
x
);
fmha
::
sts
(
offset
+
8
*
BYTES_PER_ROW
,
z
);
offset
^=
4
*
Base
::
BYTES_PER_STS
;
fmha
::
sts
(
offset
+
0
*
BYTES_PER_ROW
,
y
);
fmha
::
sts
(
offset
+
8
*
BYTES_PER_ROW
,
w
);
}
}
}
template
<
int
M
,
int
N
>
inline
__device__
void
store
(
const
uint4
(
&
regs
)[
M
][
N
])
{
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t
offset
=
(
this
->
write_offset_
^
(
ni
*
32
))
+
mi
*
WARPS_M
*
16
*
BYTES_PER_ROW
;
fmha
::
sts
(
this
->
smem_
+
offset
+
0
*
BYTES_PER_ROW
,
regs
[
mi
][
ni
].
x
);
fmha
::
sts
(
this
->
smem_
+
offset
+
8
*
BYTES_PER_ROW
,
regs
[
mi
][
ni
].
z
);
offset
^=
4
*
Base
::
BYTES_PER_STS
;
fmha
::
sts
(
this
->
smem_
+
offset
+
0
*
BYTES_PER_ROW
,
regs
[
mi
][
ni
].
y
);
fmha
::
sts
(
this
->
smem_
+
offset
+
8
*
BYTES_PER_ROW
,
regs
[
mi
][
ni
].
w
);
}
}
}
// uint32_t read_offset_;
uint32_t
smem_read_
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Cta_tile
>
struct
Smem_tile_transpose
{
using
Mma_tile
=
fmha
::
Hmma_tile
<
Cta_tile
>
;
using
Fragment_write
=
fmha
::
Fragment_b
<
fmha
::
Col
>
;
using
Fragment_read
=
fmha
::
Fragment_b
<
fmha
::
Col
>
;
enum
{
COLS
=
Cta_tile
::
N
};
enum
{
BYTES_PER_ELT
=
2
};
enum
{
BYTES_PER_STS
=
4
};
enum
{
BYTES_PER_ROW
=
COLS
*
BYTES_PER_ELT
};
// TODO
enum
{
BYTES_PER_TILE
=
Cta_tile
::
M
*
BYTES_PER_ROW
};
enum
{
BYTES_PER_LDS
=
16
};
enum
{
WARPS_M
=
Cta_tile
::
WARPS_M
};
enum
{
WARPS_N
=
Cta_tile
::
WARPS_N
};
enum
{
WARPS_K
=
Cta_tile
::
WARPS_K
};
static_assert
(
WARPS_K
==
1
);
static_assert
(
WARPS_M
==
1
&&
(
WARPS_N
==
4
||
WARPS_N
==
8
));
inline
__device__
Smem_tile_transpose
(
char
*
smem
,
int
tidx
)
{
smem_
=
__nvvm_get_smem_pointer
(
smem
);
// uint32_t smem_ = __nvvm_get_smem_pointer(smem);
int
write_col
,
write_row
;
static_assert
(
WARPS_M
==
1
&&
(
WARPS_N
==
4
||
WARPS_N
==
8
)
||
(
WARPS_M
==
4
||
WARPS_N
==
8
)
||
WARPS_N
==
1
);
if
(
WARPS_M
==
1
&&
(
WARPS_N
==
4
||
WARPS_N
==
8
)
)
{
write_row
=
(
tidx
&
0x1c
)
/
4
;
write_col
=
(
tidx
&
0xe0
)
/
4
+
(
tidx
&
0x03
);
}
else
{
write_row
=
(
tidx
&
0xe0
)
/
2
+
(
tidx
&
0x1c
)
/
4
;
write_col
=
(
tidx
&
0x03
);
}
write_col
^=
(
write_row
&
0x07
)
*
4
;
write_offset_
=
write_row
*
BYTES_PER_ROW
+
write_col
*
BYTES_PER_STS
;
// smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
int
read_row
,
read_col
;
read_row
=
(
tidx
&
0x0f
);
read_col
=
(
tidx
&
0xe0
)
/
16
+
(
tidx
&
0x1c
)
/
16
;
read_col
^=
(
read_row
&
0x07
);
read_offset_
=
read_row
*
BYTES_PER_ROW
+
read_col
*
BYTES_PER_LDS
;
// smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
template
<
int
M
,
int
N
>
inline
__device__
void
store
(
const
Fragment_write
(
&
frag_w
)[
M
][
N
],
int
mi
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t
offset
=
write_offset_
+
ni
*
WARPS_N
*
16
*
BYTES_PER_ELT
;
fmha
::
sts
(
smem_
+
offset
+
0
*
BYTES_PER_ROW
,
frag_w
[
ni
][
mi
].
reg
(
0
));
fmha
::
sts
(
smem_
+
offset
+
8
*
BYTES_PER_ROW
,
frag_w
[
ni
][
mi
].
reg
(
2
));
offset
^=
4
*
BYTES_PER_STS
;
fmha
::
sts
(
smem_
+
offset
+
0
*
BYTES_PER_ROW
,
frag_w
[
ni
][
mi
].
reg
(
1
));
fmha
::
sts
(
smem_
+
offset
+
8
*
BYTES_PER_ROW
,
frag_w
[
ni
][
mi
].
reg
(
3
));
}
}
template
<
int
N
>
inline
__device__
void
load
(
Fragment_read
(
&
frag_r
)[
N
])
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t
offset
=
read_offset_
+
ni
*
WARPS_N
*
16
*
BYTES_PER_ELT
;
uint4
dst
;
fmha
::
ldsmt
(
dst
,
this
->
smem_
+
offset
);
frag_r
[
ni
].
reg
(
0
)
=
dst
.
x
;
frag_r
[
ni
].
reg
(
1
)
=
dst
.
y
;
// Fragment B regs col major!
frag_r
[
ni
].
reg
(
2
)
=
dst
.
z
;
frag_r
[
ni
].
reg
(
3
)
=
dst
.
w
;
}
}
template
<
int
M
,
int
N
>
inline
__device__
void
transpose
(
const
Fragment_write
(
&
frag_w
)[
M
][
N
],
Fragment_read
(
&
frag_r
)[
M
],
int
mi
)
{
static_assert
(
COLS
==
Cta_tile
::
N
);
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t
offset
=
write_offset_
+
ni
*
WARPS_N
*
16
*
BYTES_PER_ELT
;
fmha
::
sts
(
smem_
+
offset
+
0
*
BYTES_PER_ROW
,
frag_w
[
ni
][
mi
].
reg
(
0
));
fmha
::
sts
(
smem_
+
offset
+
8
*
BYTES_PER_ROW
,
frag_w
[
ni
][
mi
].
reg
(
2
));
offset
^=
4
*
BYTES_PER_STS
;
fmha
::
sts
(
smem_
+
offset
+
0
*
BYTES_PER_ROW
,
frag_w
[
ni
][
mi
].
reg
(
1
));
fmha
::
sts
(
smem_
+
offset
+
8
*
BYTES_PER_ROW
,
frag_w
[
ni
][
mi
].
reg
(
3
));
}
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t
offset
=
read_offset_
+
ni
*
WARPS_N
*
16
*
BYTES_PER_ELT
;
uint4
dst
;
fmha
::
ldsmt
(
dst
,
this
->
smem_
+
offset
);
frag_r
[
ni
].
reg
(
0
)
=
dst
.
x
;
frag_r
[
ni
].
reg
(
1
)
=
dst
.
y
;
// Fragment B regs col major!
frag_r
[
ni
].
reg
(
2
)
=
dst
.
z
;
frag_r
[
ni
].
reg
(
3
)
=
dst
.
w
;
}
}
uint32_t
smem_
;
uint32_t
write_offset_
;
uint32_t
read_offset_
;
// uint32_t smem_write_;
// uint32_t smem_read_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
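// [Added illustration, not part of the original file.] Smem_tile_dp_sum below reuses the same
// circular-buffer pattern as the tiles above: advance by one buffer, and wrap once the offset
// reaches the last buffer. A minimal sketch of that advance in element counts (ROWS per buffer),
// mirroring move_to_next_read_buffer(); the helper name and the example sizes are mine.
static constexpr int example_advance_buffer(int offset, int rows, int buffers_per_tile) {
    int inc_boundary = rows * buffers_per_tile - rows;
    if( buffers_per_tile > 1 && offset >= inc_boundary ) { return offset - inc_boundary; }
    if( buffers_per_tile > 1 ) { return offset + rows; }
    return offset;
}
// With 3 buffers of 16 rows the read offset cycles 0 -> 16 -> 32 -> 0.
static_assert(example_advance_buffer( 0, 16, 3) == 16, "");
static_assert(example_advance_buffer(16, 16, 3) == 32, "");
static_assert(example_advance_buffer(32, 16, 3) ==  0, "");

////////////////////////////////////////////////////////////////////////////////////////////////////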
template
<
typename
Gmem_tile
,
// The number of buffers. (Used in multistage and double buffer cases.)
int
BUFFERS_PER_TILE_
=
1
>
struct
Smem_tile_dp_sum
{
using
Cta_tile
=
typename
Gmem_tile
::
Cta_tile
;
using
Mma_tile
=
fmha
::
Hmma_tile
<
Cta_tile
>
;
// The size of each element.
static
constexpr
int
BYTES_PER_ELEMENT
=
4
;
static
constexpr
int
ROWS
=
Gmem_tile
::
ROWS
;
static
constexpr
int
THREADS_PER_ROW
=
Gmem_tile
::
THREADS_PER_ROW
;
static
constexpr
int
MMAS_M
=
Mma_tile
::
MMAS_M
;
static
constexpr
int
ROWS_PER_LDG
=
Gmem_tile
::
ROWS_PER_LDG
;
static
constexpr
int
LDGS
=
Gmem_tile
::
LDGS
;
static
constexpr
int
ROWS_PER_MMA
=
Mma_tile
::
M_PER_MMA
;
// The size of one buffer in bytes in shared memory.
static
constexpr
int
BYTES_PER_BUFFER
=
ROWS
*
BYTES_PER_ELEMENT
;
// The number of buffers.
static
constexpr
int
BUFFERS_PER_TILE
=
BUFFERS_PER_TILE_
;
// The size in bytes of total buffers.
static
constexpr
int
BYTES_PER_TILE
=
BYTES_PER_BUFFER
*
BUFFERS_PER_TILE
;
// The boundary for smem_read_offset and smem_write_offset increment.
static
constexpr
int
ROWS_PER_TILE_INC_BOUNDARY
=
ROWS
*
BUFFERS_PER_TILE
-
ROWS
;
inline
__device__
Smem_tile_dp_sum
(
float
*
smem
,
const
int
tidx
)
:
smem_
(
smem
),
smem_read_buffer_
(
smem
),
smem_write_buffer_
(
smem
),
tidx_
(
tidx
)
{
}
// Move the read offset to next buffer.
inline
__device__
void
move_to_next_read_buffer
()
{
if
(
BUFFERS_PER_TILE
>
1
&&
(
smem_read_buffer_
-
smem_
)
>=
ROWS_PER_TILE_INC_BOUNDARY
)
{
this
->
smem_read_buffer_
-=
ROWS_PER_TILE_INC_BOUNDARY
;
}
else
if
(
BUFFERS_PER_TILE
>
1
)
{
this
->
smem_read_buffer_
+=
ROWS
;
}
}
// Move the write offset to next buffer.
inline
__device__
void
move_to_next_write_buffer
()
{
if
(
BUFFERS_PER_TILE
>
1
&&
(
smem_write_buffer_
-
smem_
)
>=
ROWS_PER_TILE_INC_BOUNDARY
)
{
this
->
smem_write_buffer_
-=
ROWS_PER_TILE_INC_BOUNDARY
;
}
else
if
(
BUFFERS_PER_TILE
>
1
)
{
this
->
smem_write_buffer_
+=
ROWS
;
}
}
inline
__device__
void
store
(
const
float
(
&
sum
)[
LDGS
])
{
if
(
tidx_
%
THREADS_PER_ROW
==
0
)
{
int
row
=
tidx_
/
THREADS_PER_ROW
;
#pragma unroll
for
(
int
i
=
0
;
i
<
LDGS
;
++
i
)
{
if
(
row
+
i
*
ROWS_PER_LDG
<
ROWS
)
{
smem_write_buffer_
[
row
+
i
*
ROWS_PER_LDG
]
=
sum
[
i
];
}
}
}
}
inline
__device__
void
store
(
const
float
sum
,
const
int
buffer_idx
)
{
float
*
smem_write
=
smem_
+
buffer_idx
*
ROWS
;
int
row
=
tidx_
/
THREADS_PER_ROW
;
if
((
row
<
ROWS
)
&&
(
tidx_
%
THREADS_PER_ROW
==
0
))
{
smem_write
[
row
]
=
sum
;
}
}
inline
__device__
void
store
(
const
float
(
&
sum
)[
LDGS
],
const
int
buffer_idx
)
{
float
*
smem_write
=
smem_
+
buffer_idx
*
ROWS
;
if
(
tidx_
%
THREADS_PER_ROW
==
0
)
{
int
row
=
tidx_
/
THREADS_PER_ROW
;
#pragma unroll
for
(
int
i
=
0
;
i
<
LDGS
;
++
i
)
{
if
(
row
+
i
*
ROWS_PER_LDG
<
ROWS
)
{
smem_write
[
row
+
i
*
ROWS_PER_LDG
]
=
sum
[
i
];
}
}
}
}
inline
__device__
void
store_pair
(
const
float
(
&
sum
)[
MMAS_M
*
2
])
{
float
*
smem_write
=
smem_
;
// Extract the position in the warp.
int
warp
=
tidx_
/
Cta_tile
::
THREADS_PER_WARP
;
int
lane
=
tidx_
%
Cta_tile
::
THREADS_PER_WARP
;
int
row
=
lane
/
4
;
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
smem_write
[
mi
*
ROWS_PER_MMA
+
row
+
0
]
=
sum
[
mi
*
2
+
0
];
smem_write
[
mi
*
ROWS_PER_MMA
+
row
+
8
]
=
sum
[
mi
*
2
+
1
];
}
}
inline
__device__
void
store_pair
(
const
float
(
&
sum
)[
MMAS_M
*
2
],
const
int
buffer_idx
)
{
float
*
smem_write
=
smem_
+
buffer_idx
*
ROWS
;
// Extract the position in the warp.
int
warp
=
tidx_
/
Cta_tile
::
THREADS_PER_WARP
;
int
lane
=
tidx_
%
Cta_tile
::
THREADS_PER_WARP
;
int
row
=
lane
/
4
;
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
smem_write
[
mi
*
ROWS_PER_MMA
+
row
+
0
]
=
sum
[
mi
*
2
+
0
];
smem_write
[
mi
*
ROWS_PER_MMA
+
row
+
8
]
=
sum
[
mi
*
2
+
1
];
}
}
template
<
int
N
>
inline
__device__
void
load
(
float
(
&
sum
)[
N
],
const
int
(
&
row
)[
N
])
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
sum
[
ni
]
=
smem_read_buffer_
[
row
[
ni
]];
}
}
template
<
int
N
>
inline
__device__
void
load
(
float
(
&
sum
)[
N
],
const
int
(
&
row
)[
N
],
const
int
buffer_idx
)
{
float
*
smem_read
=
smem_
+
buffer_idx
*
ROWS
;
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
sum
[
ni
]
=
smem_read
[
row
[
ni
]];
}
}
static
inline
__device__
float
reduce_warp
(
float
sum
)
{
fmha
::
SumOp
<
float
>
sum_op
;
return
fmha
::
Allreduce
<
THREADS_PER_ROW
>::
run
(
sum
,
sum_op
);
}
const
int
tidx_
;
float
*
const
smem_
;
float
*
smem_read_buffer_
;
float
*
smem_write_buffer_
;
};
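// [Added illustration, not part of the original file.] Smem_tile_dp_sum::reduce_warp() above
// delegates to fmha::Allreduce, a butterfly all-reduce over the THREADS_PER_ROW lanes of a warp.
// A self-contained sketch of that pattern using warp shuffles (the real helper, defined elsewhere,
// takes a generic reduction operator; this sketch assumes a power-of-two THREADS and a plain sum):
template< int THREADS >
__device__ inline float example_allreduce_sum(float x) {
    #pragma unroll
    for( int offset = THREADS / 2; offset > 0; offset /= 2 ) {
        x += __shfl_xor_sync(0xffffffff, x, offset);
    }
    return x;
}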
}  // namespace fmha
csrc/flash_attn/src/fmha/softmax.h
deleted
100644 → 0
View file @
6d48e14a
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once

#include <cmath>
#include <cuda_fp16.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float apply_exp_(float x, float max) {
    return __expf(x - max);
}

inline __device__ float apply_exp2_(float x, float max) {
    return exp2f(x - max);
    // With fast-math, this produces the same PTX instruction as the assembly below
    // float diff = x - max;
    // float res;
    // asm ("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff));
    // return res;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int COLS> struct ReadType {};
template<> struct ReadType<4> { using T = float;  };
template<> struct ReadType<8> { using T = float2; };

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile, typename Kernel_traits>
struct Smem_tile_reduce {
    // Helper class to distribute MMA tiles reduced over rows per warp over quads.

    // The Mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The number of MMAs in M/N dimensions.
    static constexpr int MMAS_M = Mma_tile::MMAS_M;
    static constexpr int MMAS_N = Mma_tile::MMAS_N;

    static constexpr int WARPS_M = Cta_tile::WARPS_M;
    static constexpr int WARPS_N = Cta_tile::WARPS_N;

    static constexpr int ROWS = WARPS_M * MMAS_M * 16;
    static constexpr int COLS = WARPS_N;
    static_assert(COLS == 4 || COLS == 8);
    static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
    static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
    static constexpr int ELTS_PER_TILE = ROWS * COLS;

    static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
    // TD [2022-05-02]: No longer true if head_dim != 64
    // static_assert(THREADS_PER_GROUP == 16); // DEBUG
    static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
    static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
    static_assert(LOOPS == 1);

    using read_t = typename ReadType<COLS>::T;

    __device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
        int lane = tidx % 32;
        int warp = tidx / 32;

        int warp_m = warp % WARPS_M;
        int warp_n = warp / WARPS_M;

        qid_ = lane % 4;
        int qp = lane / 4;

        // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
        // This won't affect reading as we assume commutative reduction ops.
        const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
        smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
        smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
        smem_read_row_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qid_];
    }

    __device__ inline void store(float (&frag)[2 * MMAS_M]) {
        if (qid_ == 0) {
            #pragma unroll
            for (int mi = 0; mi < MMAS_M; mi++) {
                int offset = mi * 16 * WARPS_N;
                smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
                smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
            }
        }
    }

    __device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; mi++) {
            int offset = mi * 16 * 4;
            frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
            frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
        }
    }

    __device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; mi++) {
            int offset = mi * 16 * 4;
            frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4];
        }
    }

    int qid_;
    float *smem_write_;
    read_t *smem_read_;
    read_t *smem_read_row_;
};
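// A minimal compile-time sketch of the XOR swizzle used in the constructor above, assuming the
// COLS == 8 case where ROWS_PER_XOR_PATTERN == 4: every block of 4 quad-rows flips the low bit
// of the warp column, so warps that would otherwise hit the same shared-memory bank are spread
// across two banks. The helper name and example values below are illustrative only.
constexpr int swizzled_col_sketch(int warp_n, int qp) { return warp_n ^ (qp / 4); }
static_assert(swizzled_col_sketch(/*warp_n=*/1, /*qp=*/0) == 1, "qp 0..3 keeps the column");
static_assert(swizzled_col_sketch(/*warp_n=*/1, /*qp=*/4) == 0, "qp 4..7 flips the low bit");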
template<typename Cta_tile, typename Kernel_traits>
struct Softmax_base {

    // The Mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The number of MMAs in M/N dimensions.
    static constexpr int MMAS_M = Mma_tile::MMAS_M;
    static constexpr int MMAS_N = Mma_tile::MMAS_N;

    // The number of groups of warp such that we have at most 4 warps writing consecutive elements.
    static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4);
    // The number of elements that we are going to store per row.
    static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS;
    // The number of rows.
    static constexpr int ROWS = Cta_tile::M * GROUPS;
    // The total number of elements.
    static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW;

    // Ctor.
    template<typename Params>
    inline __device__ Softmax_base(const Params &params, void *smem, int tidx)
        : // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
          smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {

        // Move to the 1st mask loaded by the thread+ tidx;
        // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);

        // Extract the position in the warp.
        int warp = tidx / Cta_tile::THREADS_PER_WARP;
        int lane = tidx % Cta_tile::THREADS_PER_WARP;

        // Decompose the warp index into M and N.
        int warp_m = warp % Cta_tile::WARPS_M;
        int warp_n = warp / Cta_tile::WARPS_M;

        // Decompose the warp-n index into group/position-inside-the-group.
        int warp_g = warp_n / ELEMENTS_PER_ROW;
        int warp_i = warp_n % ELEMENTS_PER_ROW;

        // The location written by the threads.
        int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
        int write_col = warp_i;

        // Assemble the write pointer.
        smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];

        // Assemble the read pointer.
        smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
    }

    template<bool zero = false, typename Mask>
    inline __device__ void apply_mask(const Mask &mask) {
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; ++mi) {
            #pragma unroll
            for (int ii = 0; ii < 2; ++ii) {
                #pragma unroll
                for (int ni = 0; ni < MMAS_N; ++ni) {
                    #pragma unroll
                    for (int jj = 0; jj < 4; ++jj) {
                        if (!mask.is_valid(mi, ni, ii, jj)) {
                            elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY;
                        }
                    }
                }
            }
        }
    }

    // Apply the exp to all the elements.
    template<bool max_in_base2 = false, bool elt_in_base2 = false>
    inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
        #pragma unroll
        for (int mi = 0; mi < MMAS_M * 2; ++mi) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            constexpr float kLog2e = M_LOG2E;
            const float max_base2 = max_in_base2 ? max[mi] : max[mi] * kLog2e;
            #pragma unroll
            for (int ni = 0; ni < MMAS_N * 4; ++ni) {
                // elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
                elt_[mi][ni] = apply_exp2_(elt_in_base2 ? elt_[mi][ni] : elt_[mi][ni] * kLog2e,
                                           max_base2);
            }
        }
    }

    // Apply the exp to all the elements.
    template<bool scale_max = true>
    inline __device__ void scale_apply_exp(const float (&max)[MMAS_M * 2], const float scale_) {
        const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E;
        const float scale = scale_ * M_LOG2E;
        #pragma unroll
        for (int mi = 0; mi < MMAS_M * 2; ++mi) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            const float max_scaled = max[mi] * max_scale;
            #pragma unroll
            for (int ni = 0; ni < MMAS_N * 4; ++ni) {
                elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled);
            }
        }
    }

    // Apply the exp to all the elements.
    template<bool max_in_base2 = false>
    inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) {
        #pragma unroll
        for (int ni = 0; ni < MMAS_N * 4; ++ni) {
            constexpr float kLog2e = M_LOG2E;
            const float max_base2 = max_in_base2 ? max[ni] : max[ni] * kLog2e;
            #pragma unroll
            for (int mi = 0; mi < MMAS_M * 2; ++mi) {
                elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
            }
        }
    }

    // inline __device__ void apply_exp_col(const float (&max)[MMAS_N]) {
    //     constexpr float kLog2e = M_LOG2E;
    //     #pragma unroll
    //     for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
    //         float max_base2 = max_in_base2 ? max[ni / 4] : max[ni / 4] * kLog2e;
    //         max_base2 = __shfl_sync(0xffffffff, max_base2, (ni % 4) * 8 + threadIdx.x % 8);
    //         #pragma unroll
    //         for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
    //             elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
    //         }
    //     }
    // }

    template<bool encode_dropout_in_sign_bit = false>
    inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
        // We encode the dropout pattern in the sign bit of the non-negative
        // softmax to distinguish from pre-existing zeros
        auto encode_dropout = [](bool keep, float val) {
            return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
        };
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; mi++) {
            #pragma unroll
            for (int ni = 0; ni < MMAS_N; ni++) {
                uint16_t tmp[8];
                // fmha::uint4_to_ushort8(ph(), tmp);
                uint4 tmp_32 = ph();
                fmha::uint4_to_ushort8(tmp_32, tmp);
                // if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                //     printf("tidx = %d, ni = %d, ph Philox: %u, %u, %u, %u\n", threadIdx.x, ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
                // }
                #pragma unroll
                for (int ii = 0; ii < 2; ++ii) {
                    #pragma unroll
                    for (int jj = 0; jj < 4; ++jj) {
                        elt_[mi * 2 + ii][4 * ni + jj] =
                            encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t,
                                           elt_[mi * 2 + ii][4 * ni + jj]);
                    }
                }
            }
        }
    }

    template<bool encode_dropout_in_sign_bit = false>
    inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t,
                                                unsigned long long philox_subsequence) {
        // We encode the dropout pattern in the sign bit of the non-negative
        // softmax to distinguish from pre-existing zeros
        auto encode_dropout = [](bool keep, float val) {
            return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
        };
        static_assert(MMAS_M == 1);  // We're assuming 16x16 blocks.
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; mi++) {
            #pragma unroll
            for (int ni = 0; ni < MMAS_N; ni++) {
                uint16_t tmp[8];
                // fmha::uint4_to_ushort8(ph(), tmp);
                fmha::uint4_to_ushort8(ph(philox_subsequence + ni * Cta_tile::WARPS_N), tmp);
                // uint4 tmp_32 = ph(philox_subsequence + ni * Cta_tile::WARPS_N);
                // fmha::uint4_to_ushort8(tmp_32, tmp);
                // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                //     printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
                // }
                #pragma unroll
                for (int ii = 0; ii < 2; ++ii) {
                    #pragma unroll
                    for (int jj = 0; jj < 4; ++jj) {
                        elt_[mi * 2 + ii][4 * ni + jj] =
                            encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t,
                                           elt_[mi * 2 + ii][4 * ni + jj]);
                    }
                }
            }
        }
    }

    template<bool encode_dropout_in_sign_bit = false>
    inline __device__ void apply_dropout_16bits(Philox &ph0, Philox &ph1,
                                                uint16_t p_dropout_in_uint16_t) {
        // We encode the dropout pattern in the sign bit of the non-negative
        // softmax to distinguish from pre-existing zeros
        auto encode_dropout = [](bool keep, float val) {
            return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
        };
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; mi++) {
            static_assert(MMAS_N % 2 == 0);
            #pragma unroll
            for (int ni = 0; ni < MMAS_N; ni += 2) {
                uint16_t tmp[8];
                fmha::uint4_to_ushort8(ph0(), tmp);
                // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                //     printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
                // }
                #pragma unroll
                for (int ii = 0; ii < 2; ++ii) {
                    #pragma unroll
                    for (int jj = 0; jj < 4; ++jj) {
                        elt_[mi * 2 + ii][4 * ni + jj] =
                            encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t,
                                           elt_[mi * 2 + ii][4 * ni + jj]);
                    }
                }
                fmha::uint4_to_ushort8(ph1(), tmp);
                // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                //     printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
                // }
                #pragma unroll
                for (int ii = 0; ii < 2; ++ii) {
                    #pragma unroll
                    for (int jj = 0; jj < 4; ++jj) {
                        elt_[mi * 2 + ii][4 * (ni + 1) + jj] =
                            encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t,
                                           elt_[mi * 2 + ii][4 * (ni + 1) + jj]);
                    }
                }
            }
        }
    }

    // Scale all the elements.
    inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
        // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
        float inv_sum[MMAS_M * 2];
        #pragma unroll
        for (int mi = 0; mi < MMAS_M * 2; ++mi) {
            inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
        }

        // Update the values.
        #pragma unroll
        for (int mi = 0; mi < MMAS_M * 2; ++mi) {
            #pragma unroll
            for (int ni = 0; ni < MMAS_N * 4; ++ni) {
                elt_[mi][ni] *= inv_sum[mi];
            }
        }
    }

    // Subtract all elements by dp_sum
    inline __device__ void subtract_dp_sum(const float (&dp_sum)[MMAS_M * 2]) {
        #pragma unroll
        for (int mi = 0; mi < MMAS_M * 2; ++mi) {
            #pragma unroll
            for (int ni = 0; ni < MMAS_N * 4; ++ni) {
                elt_[mi][ni] -= dp_sum[mi];
            }
        }
    }

    // The pointer to the mask.
    const char *packed_mask_ptr_;
    // Shared memory for the CTA-wide reduction.
    float *smem_, *smem_write_, *smem_read_;
    // The current thread index.
    int tidx_;
    // The elements.
    float elt_[MMAS_M * 2][MMAS_N * 4];
};
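// A small sketch of the base-2 rewrite used by apply_exp()/scale_apply_exp() above:
// exp(x - m) == exp2(x * log2(e) - m * log2(e)), and the multiply-then-subtract form lets the
// compiler fuse the scaling into a single FFMA before the EX2 instruction. The helper name is
// illustrative; it only reuses apply_exp2_() and the same M_LOG2E constant as the members above.
inline __device__ float apply_exp_base2_sketch(float x, float m) {
    constexpr float kLog2e = M_LOG2E;
    return apply_exp2_(x * kLog2e, m * kLog2e);   // equals exp(x - m) up to rounding
}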
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile, typename Kernel_traits>
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {

    // The base class.
    using Base = Softmax_base<Cta_tile, Kernel_traits>;
    // The fragment.
    using Fragment_a = fmha::Fragment_a<fmha::Row>;

    static_assert(Fragment_a::NUM_REGS == 4);

    static constexpr int WARPS_M = Cta_tile::WARPS_M;
    static constexpr int WARPS_N = Cta_tile::WARPS_N;
    // The MMAs.
    static constexpr int MMAS_M = Base::MMAS_M;
    static constexpr int MMAS_N = Base::MMAS_N;

    // The accumulators.
    using Accumulator = fmha::Fragment_accumulator;
    using Accumulator_out = Fragment<uint16_t, 8>;
    static_assert(Accumulator_out::NUM_REGS == 4);

    static_assert(std::is_same<Accumulator::Data_type, float>::value);

    using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
    static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);

    // Ctor.
    template<typename Params>
    inline __device__ Softmax(const Params &params, void *smem, int tidx)
        : Base(params, smem, tidx)
        , params_scale_bmm1_(params.scale_bmm1)
        , smem_sum_(static_cast<float *>(smem), tidx)
        , smem_max_(static_cast<float *>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
    }

    // Pack the data to a fragment for the next GEMM.
    template<typename elem_type = __half, int K, int M>
    inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
        #pragma unroll
        for (int mi = 0; mi < M; ++mi) {
            #pragma unroll
            for (int ki = 0; ki < K; ++ki) {

                // 1st row - 4 elements per row.
                float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
                float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
                float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];
                float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];

                // 2nd row - 4 elements per row.
                float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];
                float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];
                float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];
                float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];

                // Pack to 4 registers.
                dst[ki][mi].reg(0) = fmha::float2_pack<elem_type>(tmp_00, tmp_01);
                dst[ki][mi].reg(1) = fmha::float2_pack<elem_type>(tmp_10, tmp_11);
                dst[ki][mi].reg(2) = fmha::float2_pack<elem_type>(tmp_02, tmp_03);
                dst[ki][mi].reg(3) = fmha::float2_pack<elem_type>(tmp_12, tmp_13);
            }
        }
    }

    // Scale FP32 fragments
    inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
        const float scalef = reinterpret_cast<const float &>(this->params_scale_bmm1_);
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; ++mi) {
            #pragma unroll
            for (int ni = 0; ni < MMAS_N; ++ni) {
                // 1st row - 4 elements per row.
                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;
                // 2nd row - 4 elements per row.
                this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;
            }
        }
    }

    // Scale FP32 fragments
    inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; ++mi) {
            #pragma unroll
            for (int ni = 0; ni < MMAS_N; ++ni) {
                // 1st row - 4 elements per row.
                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
                this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
                this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
                // 2nd row - 4 elements per row.
                this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
            }
        }
    }

    template<bool zero_init = true, typename Operator>
    __device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
        #pragma unroll
        for (int mi = 0; mi < 2 * MMAS_M; mi++) {
            frag[mi] = zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]);
            #pragma unroll
            for (int ni = 1; ni < 4 * MMAS_N; ni++) {
                frag[mi] = op(frag[mi], this->elt_[mi][ni]);
            }
        }
    }

    template<bool zero_init = true, typename Operator>
    __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red &smem_red) {
        thread_reduce_<zero_init>(frag, op);
        quad_reduce(frag, frag, op);
        smem_red.store(frag);
        __syncthreads();
        typename Smem_tile_red::read_t tmp[2 * MMAS_M];
        smem_red.load(tmp);
        quad_allreduce(frag, tmp, op);
    }

    template<bool zero_init = true>
    __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]) {
        MaxOp<float> max;
        reduce_<zero_init>(frag, max, smem_max_);
    }

    __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]) {
        SumOp<float> sum;
        reduce_(frag, sum, smem_sum_);
    }

    template<bool zero_init = true>
    __device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]) {
        SumOp<float> sum;
        thread_reduce_<zero_init>(frag, sum);
        quad_reduce(frag, frag, sum);
        smem_sum_.store(frag);
    }

    template<int NROWS, typename Operator>
    __device__ inline void reduce_after_sync_(float (&frag)[NROWS][MMAS_M],
                                              const int (&rows)[NROWS],
                                              Operator &op, Smem_tile_red &smem_red) {
        #pragma unroll
        for (int ii = 0; ii < NROWS; ii++) {
            typename Smem_tile_red::read_t tmp[MMAS_M];
            smem_red.load_row(tmp, rows[ii]);
            quad_allreduce(frag[ii], tmp, op);
        }
    }

    template<int NROWS>
    __device__ inline void reduce_sum_after_sync_(float (&frag)[NROWS][MMAS_M],
                                                  const int (&rows)[NROWS]) {
        SumOp<float> sum;
        reduce_after_sync_(frag, rows, sum, smem_sum_);
    }

    template<int NROWS>
    __device__ inline void reduce_max_after_sync_(float (&frag)[NROWS][MMAS_M],
                                                  const int (&rows)[NROWS]) {
        MaxOp<float> max;
        reduce_after_sync_(frag, rows, max, smem_max_);
    }

    const uint32_t params_scale_bmm1_;
    Smem_tile_red smem_max_;
    Smem_tile_red smem_sum_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
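One detail worth calling out from Softmax_base::scale() above is the guarded inverse sum: a fully masked row has a sum of 0 (or NaN after -INFINITY masking), and forcing the inverse to 1.f keeps the row finite instead of producing Inf/NaN. A minimal standalone sketch of that guard, with a made-up row length:

// Row normalization with the same zero/NaN guard as Softmax_base::scale().
void normalize_row_sketch(float *row, int n, float sum) {
    // sum != sum is the classic NaN test; both cases fall back to an inverse of 1.f.
    float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
    for (int i = 0; i < n; ++i) {
        row[i] *= inv_sum;
    }
}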
csrc/flash_attn/src/fmha/utils.h deleted 100644 → 0 View file @ 6d48e14a
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Row {};
struct Col {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int M, bool = (M & (M - 1)) == 0>
struct Next_power_of_two {
};

template<int M> struct Next_power_of_two<M, true> { enum { VALUE = M }; };
template<> struct Next_power_of_two<  3, false> { enum { VALUE =   4 }; };
template<> struct Next_power_of_two<  5, false> { enum { VALUE =   8 }; };
template<> struct Next_power_of_two<  6, false> { enum { VALUE =   8 }; };
template<> struct Next_power_of_two<  7, false> { enum { VALUE =   8 }; };
template<> struct Next_power_of_two<  9, false> { enum { VALUE =  16 }; };
template<> struct Next_power_of_two< 10, false> { enum { VALUE =  16 }; };
template<> struct Next_power_of_two< 11, false> { enum { VALUE =  16 }; };
template<> struct Next_power_of_two< 12, false> { enum { VALUE =  16 }; };
template<> struct Next_power_of_two< 13, false> { enum { VALUE =  16 }; };
template<> struct Next_power_of_two< 14, false> { enum { VALUE =  16 }; };
template<> struct Next_power_of_two< 15, false> { enum { VALUE =  16 }; };
template<> struct Next_power_of_two< 24, false> { enum { VALUE =  32 }; };
template<> struct Next_power_of_two< 48, false> { enum { VALUE =  64 }; };
template<> struct Next_power_of_two< 80, false> { enum { VALUE = 128 }; };
template<> struct Next_power_of_two< 96, false> { enum { VALUE = 128 }; };
template<> struct Next_power_of_two<112, false> { enum { VALUE = 128 }; };
template<> struct Next_power_of_two<144, false> { enum { VALUE = 256 }; };

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int N, bool = (N & (N - 1)) == 0>
struct Prev_power_of_two {
};

template<int N> struct Prev_power_of_two<N, true> { enum { VALUE = N }; };
template<> struct Prev_power_of_two<3, false> { enum { VALUE = 2 }; };
template<> struct Prev_power_of_two<5, false> { enum { VALUE = 4 }; };
template<> struct Prev_power_of_two<6, false> { enum { VALUE = 4 }; };
template<> struct Prev_power_of_two<7, false> { enum { VALUE = 4 }; };

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int M, int N>
struct Div_up {
    enum { VALUE = (M + N - 1) / N };
};

constexpr int DivUpConstexpr(int M, int N) { return (M + N - 1) / N; }

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int A, int B>
struct Max {
    enum { VALUE = A >= B ? A : B };
};

constexpr int MaxConstexpr(int A, int B) { return A >= B ? A : B; }

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int A, int B, int C>
struct Max_3 {
    enum { VALUE = Max<Max<A, B>::VALUE, C>::VALUE };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int A, int B>
struct Min {
    enum { VALUE = A <= B ? A : B };
};

constexpr int MinConstexpr(int A, int B) { return A <= B ? A : B; }

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int SIZE_IN_BYTES>
struct Uint_from_size_in_bytes {
};

template<> struct Uint_from_size_in_bytes< 1> { using Type = uint8_t;  };
template<> struct Uint_from_size_in_bytes< 2> { using Type = uint16_t; };
template<> struct Uint_from_size_in_bytes< 4> { using Type = uint32_t; };
template<> struct Uint_from_size_in_bytes< 8> { using Type = uint2;    };
template<> struct Uint_from_size_in_bytes<16> { using Type = uint4;    };

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int WARPS_M, int WARPS_N, int WARPS_K>
struct Warp_masks {
};

template<> struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; };
template<> struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; };
template<> struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; };
template<> struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; };
template<> struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; };
template<> struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; };
template<> struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; };
template<> struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; };
template<> struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; };
template<> struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; };
template<> struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; };
template<> struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; };
template<> struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; };
template<> struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; };
template<> struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; };
template<> struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; };
template<> struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; };

////////////////////////////////////////////////////////////////////////////////////////////////////
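// A few compile-time checks illustrating how the integer helpers above are meant to be used;
// the values are chosen purely for illustration.
static_assert(Next_power_of_two<48>::VALUE == 64, "");
static_assert(Div_up<129, 128>::VALUE == 2, "");
static_assert(DivUpConstexpr(7, 4) == 2, "");
static_assert(MinConstexpr(3, 5) == 3 && MaxConstexpr(3, 5) == 5, "");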
template<typename T>
inline __device__ __host__ T div_up(T m, T n) {
    return (m + n - 1) / n;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline int clz(int x) {
    for (int i = 31; i >= 0; --i) {
        if ((1 << i) & x) { return 31 - i; }
    }
    return 32;
}

inline int find_log_2(int x, bool round_up = false) {
    int a = 31 - clz(x);
    if (round_up) { a += (x & (x - 1)) ? 1 : 0; }
    return a;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {
    uint32_t c;
    asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
    return c;
}

static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
    uint32_t c;
    asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
    return c;
}

static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) {
    // uint32_t c;
    // asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
    // return c;
    __half2 result = __hmul2(reinterpret_cast<const __half2 (&)>(a),
                             reinterpret_cast<const __half2 (&)>(b));
    return reinterpret_cast<uint32_t (&)>(result);
}

static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
    uint2 c;
    c.x = hmul2(a.x, b.x);
    c.y = hmul2(a.y, b.y);
    return c;
}

static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
    uint4 c;
    c.x = hmul2(a.x, b.x);
    c.y = hmul2(a.y, b.y);
    c.z = hmul2(a.z, b.z);
    c.w = hmul2(a.w, b.w);
    return c;
}

static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
    uint4 c;
    c.x = hmul2(a, b.x);
    c.y = hmul2(a, b.y);
    c.z = hmul2(a, b.z);
    c.w = hmul2(a, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ uint32_t hrelu2(uint32_t x);

template<>
inline __device__ uint32_t hrelu2<__half>(uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
    asm volatile( \
        "{\n" \
        "\t .reg .f16x2 sela;\n" \
        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
        "\t and.b32 %0, sela, %1;\n"
        "}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
    return res;
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t hrelu2<__nv_bfloat16>(uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
    asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
    return res;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t habs2(uint32_t x) {
    uint32_t res;
    asm volatile("abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
    return res;
}

template<typename T>
static inline __device__ T clamp(T x, T lb, T ub) {
    return x < lb ? lb : (x > ub ? ub : x);
}

static inline __device__ uint16_t clamp_to_zero(uint16_t x) {
    uint16_t mask;
    asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x));
    return mask & x;
}

static inline __device__ uint16_t float_to_half(float f) {
    uint16_t h;
    asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
    return h;
}

static inline __device__ uint32_t float2_to_half2(float a, float b) {
    uint32_t c;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
#else
    uint16_t lo = float_to_half(a);
    uint16_t hi = float_to_half(b);
    asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi));
#endif
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ uint32_t float2_pack(float a, float b);

template<>
inline __device__ uint32_t float2_pack<__half>(float a, float b) {
    __half2 result = __floats2half2_rn(a, b);
    return reinterpret_cast<uint32_t (&)>(result);
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) {
    __nv_bfloat162 result = __floats2bfloat162_rn(a, b);
    return reinterpret_cast<uint32_t (&)>(result);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t float_to_half2(float a) {
    return float2_to_half2(a, a);
}

static inline __device__ uint32_t float2_to_half2(const float2 &f) {
    return float2_to_half2(f.x, f.y);
}

static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) {
    uint2 d;
    d.x = float2_to_half2(x, y);
    d.y = float2_to_half2(z, w);
    return d;
}

template<typename T>
inline __device__ uint2 float4_pack(float x, float y, float z, float w) {
    uint2 d;
    d.x = float2_pack<T>(x, y);
    d.y = float2_pack<T>(z, w);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {
    uint32_t d;
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
    return d;
}

static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) {
    uint32_t d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
    d = hrelu2<__half>(hfma2(a, b, c));
#endif
    return d;
}

static inline __device__ uint32_t h0_h0(uint32_t x) {
    uint32_t y;
    asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" : "=r"(y) : "r"(x));
    return y;
}

static inline __device__ float h0_to_float(uint32_t h2) {
    float f;
    asm volatile("{\n" \
                 ".reg .f16 lo, hi;\n" \
                 "mov.b32 {lo, hi}, %1;\n" \
                 "cvt.f32.f16 %0, lo;\n" \
                 "}\n" : "=f"(f) : "r"(h2));
    return f;
}

static inline __device__ uint32_t h1_h1(uint32_t x) {
    uint32_t y;
    asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" : "=r"(y) : "r"(x));
    return y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {
    uint16_t d;
    asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
    return d;
}

static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) {
    return hadd2(a, b);
}

static inline __device__ uint2 hadd4(uint2 a, uint2 b) {
    uint2 c;
    c.x = hadd2(a.x, b.x);
    c.y = hadd2(a.y, b.y);
    return c;
}

static inline __device__ uint2 hadd(uint2 a, uint2 b) {
    return hadd4(a, b);
}

static inline __device__ uint4 hadd8(uint4 a, uint4 b) {
    uint4 c;
    c.x = hadd2(a.x, b.x);
    c.y = hadd2(a.y, b.y);
    c.z = hadd2(a.z, b.z);
    c.w = hadd2(a.w, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ float2 half2_unpack(uint32_t a);

template<>
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
    return __half22float2(reinterpret_cast<__half2 (&)>(a));
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
    return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// Converted two half2's or bf162's into float, then take their dot product.
template<typename T>
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
    float2 af = fmha::half2_unpack<T>(a);
    float2 bf = fmha::half2_unpack<T>(b);
    return af.x * bf.x + af.y * bf.y;
}

// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
template<typename T>
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
    float sum;
    sum  = fmha::hfma2_to_float<T>(a.x, b.x);
    sum += fmha::hfma2_to_float<T>(a.y, b.y);
    sum += fmha::hfma2_to_float<T>(a.z, b.z);
    sum += fmha::hfma2_to_float<T>(a.w, b.w);
    return sum;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint4 fadd4(uint4 a, uint4 b) {
    float4 c;
    c.x = reinterpret_cast<const float &>(a.x) + reinterpret_cast<const float &>(b.x);
    c.y = reinterpret_cast<const float &>(a.y) + reinterpret_cast<const float &>(b.y);
    c.z = reinterpret_cast<const float &>(a.z) + reinterpret_cast<const float &>(b.z);
    c.w = reinterpret_cast<const float &>(a.w) + reinterpret_cast<const float &>(b.w);
    return reinterpret_cast<const uint4 &>(c);
}

static inline __device__ uint4 fmul4(uint4 a, float b) {
    float4 c;
    c.x = reinterpret_cast<const float &>(a.x) * b;
    c.y = reinterpret_cast<const float &>(a.y) * b;
    c.z = reinterpret_cast<const float &>(a.z) * b;
    c.w = reinterpret_cast<const float &>(a.w) * b;
    return reinterpret_cast<const uint4 &>(c);
}

static inline __device__ uint4 hadd(uint4 a, uint4 b) {
    return hadd8(a, b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float half_to_float(uint16_t h) {
    float f;
    asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
    return f;
}

static inline __device__ float2 half2_to_float2(uint32_t x) {
    uint16_t lo, hi;
    asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x));
    return make_float2(half_to_float(lo), half_to_float(hi));
}

static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) {
    float2 tmp = half2_to_float2(h);
    x = tmp.x;
    y = tmp.y;
}

static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {
    uint16_t d;
    asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c));
    return d;
}

static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {
    uint16_t d;
    asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void uint4_to_ushort8(const uint4 a, uint16_t (&b)[8]) {
    uint32_t *b_tmp = reinterpret_cast<uint32_t *>(&b[0]);
    b_tmp[0] = a.x;
    b_tmp[1] = a.y;
    b_tmp[2] = a.z;
    b_tmp[3] = a.w;
}

static inline __device__ float sigmoid(float x) {
    return 1.f / (1.f + expf(-x));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void clear(uint16_t &dst) { dst = uint16_t(0); }
inline __device__ void clear(uint32_t &dst) { dst = 0u; }
inline __device__ void clear(uint2 &dst) { dst = make_uint2(0u, 0u); }
inline __device__ void clear(uint4 &dst) { dst = make_uint4(0u, 0u, 0u, 0u); }

////////////////////////////////////////////////////////////////////////////////////////////////////
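// Sketch of how the packed-half helpers above compose: scale 8 fp16 values packed in a uint4 by
// a single fp16 scalar and accumulate them into another packed uint4. h0_h0 broadcasts the low
// half into a half2, hmul8 does 8 multiplies, hadd8 does 8 adds. The helper name is illustrative
// only; it just chains the functions defined above.
static inline __device__ uint4 scale_accumulate_sketch(uint4 acc, uint4 v, uint16_t s) {
    const uint32_t s2 = h0_h0(uint32_t(s));   // broadcast the scalar into both halves of a half2
    return hadd8(acc, hmul8(s2, v));
}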
//
// P R E D I C A T E P A C K I N G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
enum
{
BYTES_PER_REG
=
4
,
PREDS_PER_BYTE
=
4
,
PREDS_PER_REG
=
BYTES_PER_REG
*
PREDS_PER_BYTE
};
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// G E N E R I C P R E D I C A T E D L D G S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
,
int
M
,
typename
Functor
>
inline
__device__
void
load_
(
Functor
&
fct
,
const
uint32_t
(
&
preds
)[
M
])
{
// The number of complete bytes (where we use all the predicates in a byte).
enum
{
COMPLETE
=
N
/
PREDS_PER_BYTE
};
// Make sure we did allocate enough predicates.
static_assert
(
Div_up
<
COMPLETE
,
BYTES_PER_REG
>::
VALUE
<=
M
,
""
);
// The remainder.
enum
{
REMAINDER
=
N
-
COMPLETE
*
PREDS_PER_BYTE
};
// Make sure we got the math right and the remainder is between 0 and 3.
static_assert
(
REMAINDER
>=
0
&&
REMAINDER
<=
3
,
""
);
// The mask to extract the predicates.
enum
{
COMPLETE_MASK
=
(
1
<<
PREDS_PER_BYTE
)
-
1
};
// Clear the fetch registers.
#pragma unroll
for
(
int
ii
=
0
;
ii
<
N
;
++
ii
)
{
fct
.
clear
(
ii
);
}
// Run complete steps.
bool
p
[
PREDS_PER_BYTE
];
#pragma unroll
for
(
int
ii
=
0
;
ii
<
COMPLETE
;
++
ii
)
{
// The predicate.
uint32_t
reg
=
preds
[
ii
/
BYTES_PER_REG
];
// Extract the predicates.
#pragma unroll
for
(
int
jj
=
0
;
jj
<
PREDS_PER_BYTE
;
++
jj
)
{
uint32_t
mask
=
1u
<<
(
ii
%
BYTES_PER_REG
*
8
+
jj
);
p
[
jj
]
=
(
reg
&
mask
)
!=
0u
;
}
// Issue the loads.
#pragma unroll
for
(
int
jj
=
0
;
jj
<
PREDS_PER_BYTE
;
++
jj
)
{
fct
.
load
(
ii
*
PREDS_PER_BYTE
+
jj
,
p
[
jj
]);
}
}
// Skip the rest of the code if we do not have a remainder.
if
(
REMAINDER
>
0
)
{
// The mask to extract the predicates.
enum
{
REMAINDER_MASK
=
(
1
<<
REMAINDER
)
-
1
};
// The predicate register.
uint32_t
reg
=
preds
[
COMPLETE
/
BYTES_PER_REG
];
// Extract the predicates.
#pragma unroll
for
(
int
jj
=
0
;
jj
<
PREDS_PER_BYTE
;
++
jj
)
{
uint32_t
mask
=
1u
<<
(
COMPLETE
%
BYTES_PER_REG
*
8
+
jj
);
p
[
jj
]
=
(
reg
&
mask
)
!=
0u
;
}
// Issue the loads.
#pragma unroll
for
(
int
ii
=
0
;
ii
<
REMAINDER
;
++
ii
)
{
fct
.
load
(
COMPLETE
*
PREDS_PER_BYTE
+
ii
,
p
[
ii
]);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
M
,
typename
Functor
>
inline
__device__
void
load_
(
Functor
&
fct
,
uint32_t
preds
)
{
uint32_t
tmp
[
1
]
=
{
preds
};
load_
<
M
>
(
fct
,
tmp
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
ldg
(
uint8_t
&
dst
,
const
void
*
ptr
)
{
dst
=
*
reinterpret_cast
<
const
uint8_t
*>
(
ptr
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
ldg
(
uint16_t
&
dst
,
const
void
*
ptr
)
{
dst
=
*
reinterpret_cast
<
const
uint16_t
*>
(
ptr
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
ldg
(
uint32_t
&
dst
,
const
void
*
ptr
)
{
dst
=
*
reinterpret_cast
<
const
uint32_t
*>
(
ptr
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
ldg
(
uint2
&
dst
,
const
void
*
ptr
)
{
dst
=
*
reinterpret_cast
<
const
uint2
*>
(
ptr
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
ldg
(
uint4
&
dst
,
const
void
*
ptr
)
{
dst
=
*
reinterpret_cast
<
const
uint4
*>
(
ptr
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Data_type
,
int
N
>
struct
Ldg_functor
{
// Ctor.
inline
__device__
Ldg_functor
(
Data_type
(
&
fetch
)[
N
],
const
void
*
(
&
ptrs
)[
N
])
:
fetch_
(
fetch
),
ptrs_
(
ptrs
)
{
}
// Clear the element.
inline
__device__
void
clear
(
int
ii
)
{
fmha
::
clear
(
fetch_
[
ii
]);
}
// Trigger the loads.
inline
__device__
void
load
(
int
ii
,
bool
p
)
{
if
(
p
)
{
ldg
(
fetch_
[
ii
],
ptrs_
[
ii
]);
}
}
// The fetch registers.
Data_type
(
&
fetch_
)[
N
];
// The pointers.
const
void
*
(
&
ptrs_
)[
N
];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Data_type
,
int
N
,
int
M
>
inline
__device__
void
ldg_
(
Data_type
(
&
fetch
)[
N
],
const
void
*
(
&
ptrs
)[
N
],
uint32_t
(
&
preds
)[
M
])
{
Ldg_functor
<
Data_type
,
N
>
fct
(
fetch
,
ptrs
);
load_
<
N
>
(
fct
,
preds
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
,
int
M
>
inline
__device__
void
ldg
(
uint8_t
(
&
fetch
)[
N
],
const
void
*
(
&
ptrs
)[
N
],
uint32_t
(
&
preds
)[
M
])
{
ldg_
<
uint8_t
,
N
>
(
fetch
,
ptrs
,
preds
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
,
int
M
>
inline
__device__
void
ldg
(
uint16_t
(
&
fetch
)[
N
],
const
void
*
(
&
ptrs
)[
N
],
uint32_t
(
&
preds
)[
M
])
{
ldg_
<
uint16_t
,
N
>
(
fetch
,
ptrs
,
preds
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
,
int
M
>
inline
__device__
void
ldg
(
uint32_t
(
&
fetch
)[
N
],
const
void
*
(
&
ptrs
)[
N
],
uint32_t
(
&
preds
)[
M
])
{
ldg_
<
uint32_t
,
N
>
(
fetch
,
ptrs
,
preds
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
,
int
M
>
inline
__device__
void
ldg
(
uint2
(
&
fetch
)[
N
],
const
void
*
(
&
ptrs
)[
N
],
uint32_t
(
&
preds
)[
M
])
{
ldg_
<
uint2
,
N
>
(
fetch
,
ptrs
,
preds
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
,
int
M
>
inline
__device__
void
ldg
(
uint4
(
&
fetch
)[
N
],
const
void
*
(
&
ptrs
)[
N
],
uint32_t
(
&
preds
)[
M
])
{
ldg_
<
uint4
,
N
>
(
fetch
,
ptrs
,
preds
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
lds
(
uint16_t
&
dst
,
uint32_t
ptr
)
{
asm
volatile
(
"ld.shared.b16 %0, [%1];
\n
"
:
"=h"
(
dst
)
:
"r"
(
ptr
));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
lds
(
uint32_t
&
dst
,
uint32_t
ptr
)
{
asm
volatile
(
"ld.shared.b32 %0, [%1];
\n
"
:
"=r"
(
dst
)
:
"r"
(
ptr
));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
lds
(
uint2
&
dst
,
uint32_t
ptr
)
{
asm
volatile
(
"ld.shared.v2.b32 {%0, %1}, [%2];
\n
"
:
"=r"
(
dst
.
x
),
"=r"
(
dst
.
y
)
:
"r"
(
ptr
));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
lds
(
uint4
&
dst
,
uint32_t
ptr
)
{
asm
volatile
(
"ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];
\n
"
:
"=r"
(
dst
.
x
)
,
"=r"
(
dst
.
y
)
,
"=r"
(
dst
.
z
)
,
"=r"
(
dst
.
w
)
:
"r"
(
ptr
));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S M
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
ldsm
(
uint32_t
&
dst
,
uint32_t
ptr
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];
\n
"
:
"=r"
(
dst
)
:
"r"
(
ptr
));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
ldsmt
(
uint32_t
&
dst
,
uint32_t
ptr
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];
\n
"
:
"=r"
(
dst
)
:
"r"
(
ptr
));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
ldsm
(
uint2
&
dst
,
uint32_t
ptr
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];
\n
"
:
"=r"
(
dst
.
x
),
"=r"
(
dst
.
y
)
:
"r"
(
ptr
));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
ldsmt
(
uint2
&
dst
,
uint32_t
ptr
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];
\n
"
:
"=r"
(
dst
.
x
),
"=r"
(
dst
.
y
)
:
"r"
(
ptr
));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
ldsm
(
uint4
&
dst
,
uint32_t
ptr
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];
\n
"
:
"=r"
(
dst
.
x
),
"=r"
(
dst
.
y
),
"=r"
(
dst
.
z
),
"=r"
(
dst
.
w
)
:
"r"
(
ptr
));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
ldsmt
(
uint4
&
dst
,
uint32_t
ptr
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];
\n
"
:
"=r"
(
dst
.
x
),
"=r"
(
dst
.
y
),
"=r"
(
dst
.
z
),
"=r"
(
dst
.
w
)
:
"r"
(
ptr
));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
stg
(
void
*
ptr
,
uint8_t
val
)
{
*
reinterpret_cast
<
uint8_t
*>
(
ptr
)
=
val
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
stg
(
void
*
ptr
,
uint16_t
val
)
{
*
reinterpret_cast
<
uint16_t
*>
(
ptr
)
=
val
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
stg
(
void
*
ptr
,
uint32_t
val
)
{
*
reinterpret_cast
<
uint32_t
*>
(
ptr
)
=
val
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
stg
(
void
*
ptr
,
uint2
val
)
{
*
reinterpret_cast
<
uint2
*>
(
ptr
)
=
val
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
stg
(
void
*
ptr
,
uint4
val
)
{
*
reinterpret_cast
<
uint4
*>
(
ptr
)
=
val
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint16_t val) {
    asm volatile ("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint32_t val) {
    asm volatile ("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint2 val) {
    asm volatile ("st.shared.v2.b32 [%0], {%1, %2};\n"
        : : "r"(ptr), "r"(val.x), "r"(val.y));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint4 val) {
    asm volatile ("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
        : : "r"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Data_type, int N>
inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {
    #pragma unroll
    for (int ii = 0; ii < N; ++ii) {
        sts(ptrs[ii], data[ii]);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int N>
inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {
    sts_<uint16_t, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int N>
inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {
    sts_<uint32_t, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int N>
inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {
    sts_<uint2, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int N>
inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
    sts_<uint4, N>(ptrs, data);
}
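
////////////////////////////////////////////////////////////////////////////////////////////////////

// Usage sketch (illustration only, not part of the original file): the array overloads of sts
// above issue one st.shared per element of a register fragment, fully unrolled. The name
// store_fragments and the flat, unswizzled address pattern are assumptions for this example.
template<int N>
inline __device__ void store_fragments(uint32_t smem_base, const uint4 (&frag)[N]) {
    uint32_t ptrs[N];
    #pragma unroll
    for (int ii = 0; ii < N; ++ii) {
        ptrs[ii] = smem_base + ii * 16;   // sizeof(uint4) == 16 bytes per fragment
    }
    sts(ptrs, frag);    // resolves to sts_<uint4, N>, i.e. N st.shared.v4.b32 stores
}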
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
    __device__ inline T operator()(T const &x, T const &y) { return x > y ? x : y; }
};

template<>
struct MaxOp<float> {
    // This is slightly faster
    __device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ inline T operator()(T const &x, T const &y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};
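
////////////////////////////////////////////////////////////////////////////////////////////////////

// Usage sketch (illustration only, not part of the original file): Allreduce<THREADS> combines a
// value across THREADS consecutive lanes with log2(THREADS) butterfly shuffles, so every lane
// ends up holding the reduced result. A typical softmax-style use of the operators above, under
// a hypothetical name (warp_max_then_expsum):
inline __device__ void warp_max_then_expsum(float x, float &row_max, float &exp_sum) {
    MaxOp<float> max_op;
    SumOp<float> sum_op;
    row_max = Allreduce<32>::run(x, max_op);                    // max over the 32 lanes of a warp
    exp_sum = Allreduce<32>::run(expf(x - row_max), sum_op);    // sum of the shifted exponentials
}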
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        dst[mi] = src[mi];
        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_reduce(__half2 (&dst)[M], __half2 (&src)[M], Operator &op) {
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        dst[mi] = src[mi];
        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
    float tmp[M];
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        tmp[mi] = op(src[mi].x, src[mi].y);
    }
    quad_reduce(dst, tmp, op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_reduce(__half2 (&dst)[M], float2 (&src)[M], Operator &op) {
    __half2 tmp[M];
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        tmp[mi] = op(reinterpret_cast<const __half2 &>(src[mi].x),
                     reinterpret_cast<const __half2 &>(src[mi].y));
    }
    quad_reduce(dst, tmp, op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        dst[mi] = src[mi];
        dst[mi] = Allreduce<4>::run(dst[mi], op);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_allreduce(__half2 (&dst)[M], __half2 (&src)[M], Operator &op) {
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        dst[mi] = src[mi];
        dst[mi] = Allreduce<4>::run(dst[mi], op);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
    float tmp[M];
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        tmp[mi] = op(src[mi].x, src[mi].y);
    }
    quad_allreduce(dst, tmp, op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_allreduce(__half2 (&dst)[M], float2 (&src)[M], Operator &op) {
    __half2 tmp[M];
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        tmp[mi] = op(reinterpret_cast<const __half2 &>(src[mi].x),
                     reinterpret_cast<const __half2 &>(src[mi].y));
    }
    quad_allreduce(dst, tmp, op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
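
// Usage sketch (illustration only, not part of the original file): quad_allreduce reduces each
// array element across a group of 4 lanes (Allreduce<4>, i.e. XOR offsets 2 and 1), matching the
// four threads that share one row of an MMA accumulator fragment. The name quad_row_sums is
// hypothetical.
template<int M>
inline __device__ void quad_row_sums(float (&row_sum)[M], float (&partial)[M]) {
    SumOp<float> sum_op;
    quad_allreduce(row_sum, partial, sum_op);   // every lane of the quad ends up with the full row sums
}

////////////////////////////////////////////////////////////////////////////////////////////////////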
}  // namespace fmha
csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu
deleted
100644 → 0
View file @
6d48e14a
/* Copyright (c) 2022, Tri Dao.
 */

#include "fmha.h"
#include "fmha_block_dgrad_kernel_1xN_loop.h"

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
    fmha::compute_block_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}

template<typename Kernel_traits>
void run_fmha_block_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream) {
    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
    constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;

    using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
    static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
    static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
    static_assert(smem_size_dp_sum == 16 * 4 * 2);

    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2)
        + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;

    bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
    bool is_causal = params.is_causal;
    auto kernel = is_dropout
        ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true>
                     : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
        : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true>
                     : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
    if (params.seqlen_k == blocksize_c) {
        kernel = is_dropout
            ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1>
                         : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
            : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1>
                         : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
    } else if (params.seqlen_k == blocksize_c * 2) {
        kernel = is_dropout
            ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2>
                         : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
            : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2>
                         : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
    }

    if (smem_size_dq_dk_dv >= 48 * 1024) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
    }
    dim3 grid(params.b, params.h);
    kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
    FMHA_CHECK_CUDA(cudaPeekAtLastError());
}

void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
    if (params.d == 16) {
        using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
        run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
    } else if (params.d == 32) {
        using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
        run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
    } else if (params.d == 64) {
        using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
        run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
    }
}
\ No newline at end of file
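
// Dispatch sketch (illustration only, not part of the commit): the nested ternaries in
// run_fmha_block_dgrad_fp16_sm80_loop_ above turn the runtime flags is_dropout / is_causal into
// compile-time template arguments by selecting among pre-instantiated kernels. The same selection
// can be written once with a small helper; bool_switch is a hypothetical name, not an API of this
// repository.
#include <type_traits>

template<typename F>
void bool_switch(bool flag, F &&f) {
    if (flag) { f(std::integral_constant<bool, true>{}); }
    else      { f(std::integral_constant<bool, false>{}); }
}

// The launcher body could then pick the kernel as:
//     bool_switch(is_dropout, [&](auto dropout) {
//         bool_switch(is_causal, [&](auto causal) {
//             kernel = &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<
//                 Kernel_traits, decltype(dropout)::value, decltype(causal)::value>;
//         });
//     });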