gaoqiong / flash-attention — Commits

Commit 8c6609ae — authored Dec 08, 2022 by Tri Dao

[LayerNorm] Support all dimensions up to 6k (if divisible by 8)

Parent: 8a2ece89
Changes: 35 files in the commit; this page shows 15 changed files with 335 additions and 375 deletions (+335 -375).
csrc/layer_norm/ln_fwd_2048.cu         +15   -0
csrc/layer_norm/ln_fwd_256.cu          +15   -0
csrc/layer_norm/ln_fwd_2560.cu         +15   -0
csrc/layer_norm/ln_fwd_3072.cu         +15   -0
csrc/layer_norm/ln_fwd_4096.cu         +15   -0
csrc/layer_norm/ln_fwd_512.cu          +15   -0
csrc/layer_norm/ln_fwd_5120.cu         +15   -0
csrc/layer_norm/ln_fwd_6144.cu         +15   -0
csrc/layer_norm/ln_fwd_768.cu          +15   -0
csrc/layer_norm/ln_fwd_cuda_kernel.cu  +0    -302
csrc/layer_norm/ln_fwd_kernels.cuh     +140  -48
csrc/layer_norm/ln_utils.cuh           +27   -18
csrc/layer_norm/setup.py               +24   -2
flash_attn/ops/layer_norm.py           +1    -2
tests/ops/test_dropout_layer_norm.py   +8    -3
csrc/layer_norm/ln_fwd_2048.cu — new file (0 → 100644)

#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
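Each of the new per-size translation units is just this list of registrations. The macro itself lives elsewhere in the extension (not on this page); the sketch below shows the general shape such a registration macro takes — a static registrar object recorded in a lookup table keyed by the type combination and hidden size, wrapping an instantiation of the launch_ template. All identifiers here (FwdRegistrar, fwd_launchers, the key format) are illustrative assumptions, not the extension's actual names.

// Hypothetical sketch of the registration pattern; names are assumptions.
#include <cstdio>
#include <functional>
#include <map>
#include <string>

using Launcher = std::function<void(bool configure_params)>;

static std::map<std::string, Launcher> &fwd_launchers() {
    static std::map<std::string, Launcher> m;  // filled in by static registrars at load time
    return m;
}

struct FwdRegistrar {
    FwdRegistrar(std::string key, Launcher f) { fwd_launchers()[std::move(key)] = std::move(f); }
};

// Each REGISTER_FWD_LAUNCHER line would expand to roughly this: instantiate the
// launch_ template for one (types, hidden size, tiling) combination and record it.
#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, ...)           \
    static FwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
        #HIDDEN_SIZE "_" #WTYPE "_" #ITYPE "_" #RTYPE "_" #OTYPE "_" #CTYPE,                 \
        [](bool) { /* would call launch_<...>(launch_params, configure_params) */ })

REGISTER_FWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);

int main() {
    std::printf("%zu launcher(s) registered\n", fwd_launchers().size());  // prints 1
}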
csrc/layer_norm/ln_fwd_256.cu — new file (0 → 100644)

#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER(256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_fwd_2560.cu — new file (0 → 100644)

#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER(2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_fwd_3072.cu — new file (0 → 100644)

#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
csrc/layer_norm/ln_fwd_4096.cu — new file (0 → 100644)

#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
csrc/layer_norm/ln_fwd_512.cu — new file (0 → 100644)

#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER(512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_fwd_5120.cu — new file (0 → 100644)

#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
csrc/layer_norm/ln_fwd_6144.cu — new file (0 → 100644)

#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER(6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER(6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER(6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER(6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER(6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER(6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER(6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER(6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER(6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER(6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
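The 6144 unit is the only configuration on this page launched with WARPS_M = 1, WARPS_N = 8, i.e. 256 threads cooperating on a single row. A quick back-of-the-envelope check, deriving NUM_ELTS from BYTES_PER_LDG the way a kernel-traits struct typically does, shows the per-thread load count this implies; the constants below are assumptions matching the macro arguments above, not values read out of ln_kernel_traits.h.

// Sketch: per-thread load count for the 6144-column configuration (fp16 inputs assumed).
#include <cstdio>

int main() {
    const int HIDDEN_SIZE = 6144;
    const int BYTES_PER_LDG = 16;        // from the macro arguments above
    const int BYTES_PER_ELT = 2;         // fp16 input elements (assumed)
    const int WARPS_N = 8, THREADS_PER_WARP = 32;

    const int NUM_ELTS = BYTES_PER_LDG / BYTES_PER_ELT;          // 8 elements per vector load
    const int THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP;      // 256 threads share one row
    const int LDGS = HIDDEN_SIZE / (NUM_ELTS * THREADS_PER_ROW); // 6144 / 2048 = 3 loads/thread

    printf("NUM_ELTS=%d THREADS_PER_ROW=%d LDGS=%d\n", NUM_ELTS, THREADS_PER_ROW, LDGS);
    return 0;
}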
csrc/layer_norm/ln_fwd_768.cu — new file (0 → 100644)

#include "ln_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER(768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_fwd_cuda_kernel.cu — deleted (100644 → 0); contents as of parent 8a2ece89
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_fwd_kernels.cuh"
#include "static_switch.h"
using
namespace
layer_norm
;
template
<
typename
weight_t
,
typename
input_t
,
typename
residual_t
,
typename
output_t
,
typename
compute_t
,
typename
index_t
,
int
HIDDEN_SIZE
,
int
CTAS_PER_ROW
,
int
WARPS_M
,
int
WARPS_N
,
int
BYTES_PER_LDG
>
void
launch_
(
LaunchParams
<
FwdParams
>
&
launch_params
,
const
bool
configure_params
){
using
Kernel_traits
=
Kernel_traits
<
weight_t
,
input_t
,
residual_t
,
output_t
,
compute_t
,
index_t
,
HIDDEN_SIZE
,
CTAS_PER_ROW
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG
>
;
bool
has_residual
=
launch_params
.
params
.
x1
!=
nullptr
;
bool
has_rowscale
=
launch_params
.
params
.
rowscale
!=
nullptr
;
BOOL_SWITCH
(
launch_params
.
params
.
dropout_keep_p
<
1.
f
,
IsDropoutConst
,
[
&
]
{
BOOL_SWITCH
(
has_residual
,
HasResidualConst
,
[
&
]
{
BOOL_SWITCH
(
has_rowscale
,
HasRowscaleConst
,
[
&
]
{
auto
kernel
=
&
ln_fwd_kernel
<
Kernel_traits
,
IsDropoutConst
,
HasResidualConst
,
HasRowscaleConst
>
;
if
(
configure_params
)
{
int
ctas_per_sm
;
CHECK_CUDA
(
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES_FWD
));
launch_params
.
params
.
ctas_per_col
=
launch_params
.
props
->
multiProcessorCount
*
ctas_per_sm
/
Kernel_traits
::
CTAS_PER_ROW
;
const
size_t
rows_per_loop
=
launch_params
.
params
.
ctas_per_col
*
Kernel_traits
::
ROWS_PER_CTA
;
launch_params
.
elts_per_thread
=
(
launch_params
.
params
.
rows
+
rows_per_loop
-
1
)
/
rows_per_loop
*
Kernel_traits
::
LDGS
*
Kernel_traits
::
NUM_ELTS
;
launch_params
.
barrier_size
=
0
;
launch_params
.
workspace_bytes
=
0
;
if
(
Kernel_traits
::
CTAS_PER_ROW
>
1
)
{
launch_params
.
barrier_size
=
2
*
launch_params
.
params
.
ctas_per_col
;
launch_params
.
workspace_bytes
=
launch_params
.
params
.
ctas_per_col
*
Kernel_traits
::
WARPS_M
*
Kernel_traits
::
CTAS_PER_ROW
*
sizeof
(
typename
Kernel_traits
::
Stats
::
stats_t
)
*
2
;
}
return
;
}
if
(
Kernel_traits
::
SMEM_BYTES_FWD
>=
48
*
1024
)
{
CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES_FWD
));
}
auto
stream
=
launch_params
.
stream
;
auto
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
;
if
(
Kernel_traits
::
CTAS_PER_ROW
==
1
)
{
kernel
<<<
ctas_per_col
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES_FWD
,
stream
>>>
(
launch_params
.
params
);
}
else
{
dim3
grid
(
Kernel_traits
::
CTAS_PER_ROW
*
ctas_per_col
);
dim3
block
(
Kernel_traits
::
THREADS_PER_CTA
);
void
*
params_
=
(
void
*
)
&
launch_params
.
params
;
cudaLaunchCooperativeKernel
((
void
*
)
kernel
,
grid
,
block
,
(
void
**
)
&
params_
,
Kernel_traits
::
SMEM_BYTES_FWD
,
stream
);
}
});
});
});
}
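The launcher converts three runtime flags into compile-time template parameters via nested BOOL_SWITCH calls, so the kernel body pays no per-element branch cost for dropout, residual, or rowscale handling. static_switch.h is not shown on this page; the sketch below is the standard shape of such a macro and a usage mirroring the launcher above (the real definition may differ in detail).

// Sketch of the BOOL_SWITCH pattern assumed by launch_: the runtime condition picks
// one of two lambda instantiations, each with the flag baked in as constexpr.
#define BOOL_SWITCH(COND, CONST_NAME, ...)     \
    do {                                       \
        if (COND) {                            \
            constexpr bool CONST_NAME = true;  \
            __VA_ARGS__();                     \
        } else {                               \
            constexpr bool CONST_NAME = false; \
            __VA_ARGS__();                     \
        }                                      \
    } while (0)

template<bool IsDropout, bool HasResidual, bool HasRowscale>
void run_kernel() { /* specialized kernel launch would go here */ }

// 2^3 = 8 instantiations are compiled; only the one matching the flags runs.
void dispatch(bool is_dropout, bool has_residual, bool has_rowscale) {
    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
        BOOL_SWITCH(has_residual, HasResidualConst, [&] {
            BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
                run_kernel<IsDropoutConst, HasResidualConst, HasRowscaleConst>();
            });
        });
    });
}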
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(1600, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER(1600, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER(1600, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER(1600, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER(1600, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER(1600, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER(1600, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER(1600, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER(1600, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER(1600, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// TD [2022-04-22] Disable most of these to speed up compile time
// REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);
csrc/layer_norm/ln_fwd_kernels.cuh — modified
@@ -10,10 +10,13 @@
 #include <curand_kernel.h>

 #include "ln.h"
+#include "ln_utils.cuh"
+#include "ln_kernel_traits.h"
+#include "static_switch.h"

 namespace layer_norm {

-template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_rowscale>
+template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Is_even_cols>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
 void ln_fwd_kernel(FwdParams params) {
@@ -73,57 +76,70 @@ void ln_fwd_kernel(FwdParams params) {
         curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
     }

+    const index_t num_valid_ldgs =
+        ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;
+
     Wvec gamma[LDGS];
     Wvec beta[LDGS];
     index_t idx = c;
     #pragma unroll
     for( int it = 0; it < LDGS; it++ ) {
-        gamma[it].load_from(params.gamma, idx);
-        beta[it].load_from(params.beta, idx);
-        idx += VEC_COLS_PER_LDG;
+        if (Is_even_cols || (it < num_valid_ldgs)) {
+            gamma[it].load_from(params.gamma, idx);
+            beta[it].load_from(params.beta, idx);
+            idx += VEC_COLS_PER_LDG;
+        }
     }

-    constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
     for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
-        const compute_t rowscale_val = Has_rowscale ? compute_t(rowscale[row]) : 1.0f;
-        index_t idx = row * Ktraits::VEC_COLS + c;
+        const compute_t rowscale_val = params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row]);
+        index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
         compute_t xf[LDGS * NUM_ELTS];
         #pragma unroll
         for( int it = 0; it < LDGS; it++ ) {
-            Ivec x0;
-            Rvec x1;
-            Rvec x;
-            Mvec dmask;
-            x0.load_from(params.x0, idx);
-            if (Has_residual) { x1.load_from(params.x1, idx); }
-            #pragma unroll
-            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
-                // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
-                // the more efficient curand_uniform4.
-                mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
-                compute_t x0_ij = Has_rowscale ? compute_t(x0.data.elt[jt]) * rowscale_val : compute_t(x0.data.elt[jt]);
-                compute_t x_ij;
-                if (Has_residual) {
-                    compute_t x1_ij = compute_t(x1.data.elt[jt]);
-                    x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
-                } else {
-                    x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
-                }
-                if (save_x) { x.data.elt[jt] = x_ij; }
-                xf[it * NUM_ELTS + jt] = x_ij;
-                if (Is_dropout) { dmask.data.elt[jt] = keep; }
-            }
-            if (save_x) { x.store_to(params.x, idx); }
-            if (Is_dropout) { dmask.store_to(params.dmask, idx); }
-            idx += VEC_COLS_PER_LDG;
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                Ivec x0;
+                Rvec x1;
+                Rvec x;
+                Mvec dmask;
+                x0.load_from(params.x0, idx);
+                if (Has_residual) { x1.load_from(params.x1, idx); }
+                #pragma unroll
+                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                    // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
+                    // the more efficient curand_uniform4.
+                    mask_t keep = true;
+                    if (Is_dropout) {
+                        float rand = curand_uniform(&state);
+                        keep = mask_t(rand <= params.dropout_keep_p);
+                    }
+                    compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
+                    compute_t x_ij;
+                    if (Has_residual) {
+                        compute_t x1_ij = compute_t(x1.data.elt[jt]);
+                        x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
+                    } else {
+                        x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
+                    }
+                    if (save_x) { x.data.elt[jt] = x_ij; }
+                    xf[it * NUM_ELTS + jt] = x_ij;
+                    if (Is_dropout) { dmask.data.elt[jt] = keep; }
+                }
+                if (save_x) { x.store_to(params.x, idx); }
+                if (Is_dropout) { dmask.store_to(params.dmask, idx); }
+                idx += VEC_COLS_PER_LDG;
+            }
         }
-        stats_t s = stats.compute(xf, rn);
+        static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");
+        const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
+        const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
+        const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
+        // Need to convert to int, otherwise the subtraction will wrap around.
+        auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
+            const index_t valid_partial_vecs_in_warp =
+                std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
+                         int(THREADS_PER_WARP));
+            return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
+        };
+        stats_t s = stats.template compute<Is_even_cols>(xf, params.inverse_cols,
+                                                         valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS);
         compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
         compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
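The new num_valid_ldgs guard is a ceiling division in disguise: for the thread that owns vector column c, ((num_vecs - 1 - c) + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG equals ceil((num_vecs - c) / VEC_COLS_PER_LDG), the number of that thread's strided loads that still land inside params.cols. A worked example, with illustrative constants (not values taken from ln_kernel_traits.h):

// Worked example of the num_valid_ldgs formula with assumed constants.
#include <cassert>
#include <cstdio>

int main() {
    const int cols = 3000;              // runtime hidden size, divisible by 8
    const int ELTS_PER_LDG = 8;         // assumed: elements per vector load
    const int VEC_COLS_PER_LDG = 128;   // assumed: vector columns covered per load round
    const int c = 10;                   // this thread's starting vector column

    const int num_vecs = cols / ELTS_PER_LDG;  // 375 vectors in the row
    const int num_valid_ldgs =
        (num_vecs - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;

    // Same quantity written as an explicit ceiling division:
    assert(num_valid_ldgs == (num_vecs - c + VEC_COLS_PER_LDG - 1) / VEC_COLS_PER_LDG);
    printf("num_valid_ldgs = %d\n", num_valid_ldgs);  // 3: loads at columns 10, 138, 266
    return 0;
}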
@@ -132,28 +148,104 @@ void ln_fwd_kernel(FwdParams params) {
             mu_ptr[row] = mu;
         }

-        compute_t rs = rsqrtf(rn * m2 + params.epsilon);
+        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon);

         if( bidn == 0 && warp_n == 0 && lane == 0 ) {
             rs_ptr[row] = rs;
         }

-        idx = row * Ktraits::VEC_COLS + c;
+        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
         #pragma unroll
         for( int it = 0; it < LDGS; it++ ) {
-            Ovec z;
-            #pragma unroll
-            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
-                output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
-                output_t g_ij = gamma[it].data.elt[jt];
-                output_t b_ij = beta[it].data.elt[jt];
-                z.data.elt[jt] = (g_ij * y_ij + b_ij);
-            }
-            z.store_to(params.z, idx);
-            idx += VEC_COLS_PER_LDG;
+            if (Is_even_cols || (it < num_valid_ldgs)) {
+                Ovec z;
+                #pragma unroll
+                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
+                    compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu));
+                    compute_t g_ij = gamma[it].data.elt[jt];
+                    compute_t b_ij = beta[it].data.elt[jt];
+                    z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
+                }
+                z.store_to(params.z, idx);
+                idx += VEC_COLS_PER_LDG;
+            }
         }
     }
 }

 }  // namespace layer_norm
+
+using namespace layer_norm;
+
+template<typename weight_t, typename input_t, typename residual_t, typename output_t,
+         typename compute_t, typename index_t,
+         int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
+void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
+    using Kernel_traits = Kernel_traits<weight_t, input_t, residual_t, output_t, compute_t, index_t,
+                                        HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
+    bool has_residual = launch_params.params.x1 != nullptr;
+    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
+    BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
+        BOOL_SWITCH(has_residual, HasResidualConst, [&] {
+            BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+                auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
+                if( configure_params ) {
+                    int ctas_per_sm;
+                    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
+                    launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
+                    const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
+                    launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
+                    launch_params.barrier_size = 0;
+                    launch_params.workspace_bytes = 0;
+                    if(Kernel_traits::CTAS_PER_ROW > 1) {
+                        launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
+                        launch_params.workspace_bytes = launch_params.params.ctas_per_col
+                                                      * Kernel_traits::WARPS_M
+                                                      * Kernel_traits::CTAS_PER_ROW
+                                                      * sizeof(typename Kernel_traits::Stats::stats_t)
+                                                      * 2;
+                    }
+                    return;
+                }
+
+                if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
+                    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
+                }
+                auto stream = launch_params.stream;
+                auto ctas_per_col = launch_params.params.ctas_per_col;
+
+                if( Kernel_traits::CTAS_PER_ROW == 1 ) {
+                    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
+                } else {
+                    dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
+                    dim3 block(Kernel_traits::THREADS_PER_CTA);
+                    void *params_ = (void *)&launch_params.params;
+                    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
+                }
+            });
+        });
+    });
+}
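How an arbitrary hidden size reaches one of these fixed-size instantiations is decided on the host side in ln_api.cpp, which this page does not show. A plausible reading of the commit is that a request is routed to the smallest registered HIDDEN_SIZE that is at least cols, with Is_even_cols == false masking the tail; the sketch below only illustrates that selection, and the function name and size table are assumptions.

// Hypothetical illustration of size selection (not the extension's actual code):
// route cols to the smallest registered kernel size that can hold the row.
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>

uint32_t pick_kernel_size(uint32_t cols) {
    // Forward sizes registered by this commit's translation units (assumed table).
    static constexpr std::array<uint32_t, 12> sizes =
        {256, 512, 768, 1024, 1280, 1536, 2048, 2560, 3072, 4096, 5120, 6144};
    auto it = std::lower_bound(sizes.begin(), sizes.end(), cols);
    return *it;  // assumes cols % 8 == 0 and cols <= 6144, per the commit message
}

int main() {
    printf("%u -> %u\n", 3000u, pick_kernel_size(3000));  // 3000 -> 3072 (Is_even_cols = false)
    printf("%u -> %u\n", 4096u, pick_kernel_size(4096));  // 4096 -> 4096 (Is_even_cols = true)
    return 0;
}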
csrc/layer_norm/ln_utils.cuh — modified
@@ -530,20 +530,20 @@ struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename T>
-inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){
+template<typename T, typename int_t>
+inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){
     //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
-    int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
+    const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);

     #pragma unroll
     for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
         // Exchange
-        T n_b = warp_shuffle_down(n_a, step);
+        int_t n_b = warp_shuffle_down(n_a, step);
         T m_b = warp_shuffle_down(m_a, step);
         T m2_b = warp_shuffle_down(m2_a, step);

         // Update
-        const T n_ab = n_a + n_b; // We can handle one of them being 0, not both.
+        const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both.
         const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
         const T delta = m_a - m_b;
         const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
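The m2_ab update in the exchange step is the standard pairwise (Chan-style) merge of two partial mean/variance summaries; this commit only changes the type of the counts from T to int_t, not the math. With partial counts n_a, n_b, means m_a, m_b, and sums of squared deviations m2_a, m2_b, the merged quantities are:

\[
n_{ab} = n_a + n_b, \qquad
\delta = m_a - m_b, \qquad
m_{ab} = \frac{n_a\, m_a + n_b\, m_b}{n_{ab}}, \qquad
m2_{ab} = m2_a + m2_b + \delta^2\,\frac{n_a\, n_b}{n_{ab}}.
\]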
@@ -647,23 +647,26 @@ struct Stats<T, 1, WARPS_M, WARPS_N> {
         smem1_ = smem0_ + WARPS_M * WARPS_N;
     }

-    template<uint32_t N>
-    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
+    template<bool Is_even_cols, uint32_t N, typename function_t>
+    inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
+                                      function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
         stats_t * smem = use0_ ? smem0_ : smem1_;
         use0_ = !use0_;
         // Compute warp local for all WARPS_N
-        constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
-        stats_t warp_stats = warp_stats_.compute(elts, warp_rn);
+        const auto warp_n = warp_stats_.reducer_.warp_n_;
+        const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n));
+        stats_t warp_stats = warp_stats_.template compute<Is_even_cols>(elts, warp_norm_factor,
+                                                                        valid_elts_in_warp_fn, num_valid_elts);

         //Each warp warp leader stores its stats
-        const auto warp_n = warp_stats_.reducer_.warp_n_;
         const auto lane = warp_stats_.reducer_.lane_;
         if( lane == 0 ) {
             smem[warp_n] = warp_stats;
         }
         __syncthreads();

-        T n = Zeros<T>::get();
+        int n = 0;
         T m = Zeros<T>::get();
         T m2 = Zeros<T>::get();
@@ -671,7 +674,7 @@ struct Stats<T, 1, WARPS_M, WARPS_N> {
         static_assert(WARPS_N <= 32);
         if( lane < WARPS_N ){
             stats_t result = smem[lane];
-            n = N * THREADS_PER_WARP;
+            n = Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane);
             m = layer_norm::Get<0>::of<stats_t, T>(result);
             m2 = layer_norm::Get<1>::of<stats_t, T>(result);
         }
@@ -703,23 +706,29 @@ struct Stats<T, 1, WARPS_M, 1> {
     {
     }

-    template<uint32_t N>
-    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
+    template<bool Is_even_cols, uint32_t N, typename function_t>
+    inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
+                                      // const int valid_elts_in_warp_ignored_,
+                                      function_t valid_elts_in_warp_fn,
+                                      const int num_valid_elts = N) {
         auto sum = Sum<T>();
         T m = Zeros<T>::get();
         #pragma unroll
         for( int it = 0; it < N; it++ ) {
-            m += elts[it];
+            if (Is_even_cols || (it < num_valid_elts)) {
+                m += elts[it];
+            }
         }
-        m = reducer_.allreduce(m, sum) * rn;
+        m = reducer_.allreduce(m, sum) * row_norm_factor;
         T m2 = Zeros<T>::get();
         #pragma unroll
         for( int it = 0; it < N; it++ ) {
-            T diff = (elts[it] - m);
-            m2 += diff * diff;
+            if (Is_even_cols || (it < num_valid_elts)) {
+                T diff = (elts[it] - m);
+                m2 += diff * diff;
+            }
         }
         m2 = reducer_.allreduce(m2, sum);
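The guarded loops change what is summed but not the estimator: skipping entries with it >= num_valid_elts and normalizing by row_norm_factor = 1 / (number of valid elements) yields exactly the mean and sum of squared deviations of the unpadded row. A host-side sanity check of that claim, with made-up data:

// Host-side check of the masked mean / m2 computation (illustrative values only).
#include <cstdio>

int main() {
    const int N = 8;                 // per-thread register tile size (padded)
    const int num_valid_elts = 5;    // only the first 5 entries are real columns
    const float elts[N] = {1.f, 2.f, 3.f, 4.f, 5.f, 99.f, 99.f, 99.f};  // tail is garbage
    const float row_norm_factor = 1.f / num_valid_elts;

    float m = 0.f;
    for (int it = 0; it < N; it++) {
        if (it < num_valid_elts) { m += elts[it]; }
    }
    m *= row_norm_factor;            // mean of the valid prefix: 3.0

    float m2 = 0.f;
    for (int it = 0; it < N; it++) {
        if (it < num_valid_elts) {
            float diff = elts[it] - m;
            m2 += diff * diff;       // sum of squared deviations: 10.0
        }
    }
    printf("mean=%.1f m2=%.1f variance=%.1f\n", m, m2, m2 * row_norm_factor);  // 3.0 10.0 2.0
    return 0;
}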
csrc/layer_norm/setup.py — modified
@@ -108,8 +108,30 @@ ext_modules.append(
         name="dropout_layer_norm",
         sources=[
             "ln_api.cpp",
-            "ln_fwd_cuda_kernel.cu",
-            "ln_bwd_semi_cuda_kernel.cu",
+            "ln_fwd_256.cu",
+            "ln_bwd_256.cu",
+            "ln_fwd_512.cu",
+            "ln_bwd_512.cu",
+            "ln_fwd_768.cu",
+            "ln_bwd_768.cu",
+            "ln_fwd_1024.cu",
+            "ln_bwd_1024.cu",
+            "ln_fwd_1280.cu",
+            "ln_bwd_1280.cu",
+            "ln_fwd_1536.cu",
+            "ln_bwd_1536.cu",
+            "ln_fwd_2048.cu",
+            "ln_bwd_2048.cu",
+            "ln_fwd_2560.cu",
+            "ln_bwd_2560.cu",
+            "ln_fwd_3072.cu",
+            "ln_bwd_3072.cu",
+            "ln_fwd_4096.cu",
+            "ln_bwd_4096.cu",
+            "ln_fwd_5120.cu",
+            "ln_bwd_5120.cu",
+            "ln_fwd_6144.cu",
+            "ln_bwd_6144.cu",
         ],
         extra_compile_args={
             "cxx": ["-O3"] + generator_flag,
flash_attn/ops/layer_norm.py — modified
@@ -2,7 +2,6 @@
 import torch
 from torch.nn import init

-# from apex._autocast_utils import _cast_if_autocast_enabled
 import dropout_layer_norm
@@ -145,7 +144,7 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No
 class DropoutAddLayerNorm(torch.nn.Module):
-    def __init__(self, hidden_size, prenorm=False, p=0.5, eps=1e-5, residual_in_fp32=False,
+    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
tests/ops/test_dropout_layer_norm.py — modified
@@ -24,8 +24,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
                           (torch.float32, torch.float32)]
                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
 # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
-@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
-# @pytest.mark.parametrize('hidden_size', [768])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
 def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                      dropout_p, has_residual, has_rowscale):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
@@ -148,7 +147,13 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
                          [(torch.float16, torch.float16), (torch.float16, torch.float32),
                           (torch.float32, torch.float32)]
                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
-@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
+# @pytest.mark.parametrize('has_rowscale', [False])
+# @pytest.mark.parametrize('has_residual', [True])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+# @pytest.mark.parametrize('weight_dtype', [torch.float32])
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
+# @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
 def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                              dropout_p, has_residual, has_rowscale):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: