Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
8c6609ae
Commit
8c6609ae
authored
Dec 08, 2022
by
Tri Dao
Browse files
[LayerNorm] Support all dimensions up to 6k (if divisible by 8)
parent
8a2ece89
Changes
35
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
446 additions
and
420 deletions
+446
-420
csrc/layer_norm/README.md
csrc/layer_norm/README.md
+4
-3
csrc/layer_norm/ln.h
csrc/layer_norm/ln.h
+2
-0
csrc/layer_norm/ln_api.cpp
csrc/layer_norm/ln_api.cpp
+17
-4
csrc/layer_norm/ln_bwd_1024.cu
csrc/layer_norm/ln_bwd_1024.cu
+15
-0
csrc/layer_norm/ln_bwd_1280.cu
csrc/layer_norm/ln_bwd_1280.cu
+15
-0
csrc/layer_norm/ln_bwd_1536.cu
csrc/layer_norm/ln_bwd_1536.cu
+15
-0
csrc/layer_norm/ln_bwd_2048.cu
csrc/layer_norm/ln_bwd_2048.cu
+15
-0
csrc/layer_norm/ln_bwd_256.cu
csrc/layer_norm/ln_bwd_256.cu
+15
-0
csrc/layer_norm/ln_bwd_2560.cu
csrc/layer_norm/ln_bwd_2560.cu
+15
-0
csrc/layer_norm/ln_bwd_3072.cu
csrc/layer_norm/ln_bwd_3072.cu
+15
-0
csrc/layer_norm/ln_bwd_4096.cu
csrc/layer_norm/ln_bwd_4096.cu
+15
-0
csrc/layer_norm/ln_bwd_512.cu
csrc/layer_norm/ln_bwd_512.cu
+15
-0
csrc/layer_norm/ln_bwd_5120.cu
csrc/layer_norm/ln_bwd_5120.cu
+15
-0
csrc/layer_norm/ln_bwd_6144.cu
csrc/layer_norm/ln_bwd_6144.cu
+15
-0
csrc/layer_norm/ln_bwd_768.cu
csrc/layer_norm/ln_bwd_768.cu
+15
-0
csrc/layer_norm/ln_bwd_kernels.cuh
csrc/layer_norm/ln_bwd_kernels.cuh
+198
-88
csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu
csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu
+0
-325
csrc/layer_norm/ln_fwd_1024.cu
csrc/layer_norm/ln_fwd_1024.cu
+15
-0
csrc/layer_norm/ln_fwd_1280.cu
csrc/layer_norm/ln_fwd_1280.cu
+15
-0
csrc/layer_norm/ln_fwd_1536.cu
csrc/layer_norm/ln_fwd_1536.cu
+15
-0
No files found.
csrc/layer_norm/README.md
View file @
8c6609ae
This CUDA extension implements fused dropout + residual + LayerNorm, b
ased
on
This CUDA extension implements fused dropout + residual + LayerNorm, b
uilding
on
Apex's
[
FastLayerNorm
](
https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm
)
.
We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
We also make it work for more hidden dimensions (all dimensions divisible by 8, up to 6144).
This only supports a limited set of dimensions, see
`csrc/layer_norm/ln_fwd_cuda_kernel.cu`
.
If you want to use it for dimensions larger than 6k, please file an issue
.
It
has only been tested on A100s.
This extension
has only been tested on A100s.
```
sh
cd
csrc/layer_norm
&&
pip
install
.
...
...
csrc/layer_norm/ln.h
View file @
8c6609ae
...
...
@@ -64,6 +64,8 @@ struct ParamsBase {
void
*
gamma
;
void
*
rowscale
;
float
inverse_cols
;
float
dropout_keep_p
;
float
dropout_scale
;
...
...
csrc/layer_norm/ln_api.cpp
View file @
8c6609ae
...
...
@@ -129,6 +129,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
TORCH_CHECK
(
gamma
.
sizes
()
==
beta
.
sizes
());
TORCH_CHECK
(
hidden_size
==
cols
);
TORCH_CHECK
((
hidden_size
%
8
==
0
)
&&
(
hidden_size
<=
6144
));
TORCH_CHECK
(
epsilon
>=
0.
f
);
...
...
@@ -156,8 +157,10 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
auto
gen
=
at
::
get_generator_or_default
<
at
::
CUDAGeneratorImpl
>
(
gen_
,
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
());
auto
round_multiple
=
[](
int
x
,
int
m
)
{
return
(
x
+
m
-
1
)
/
m
*
m
;
};
const
int
multiple
=
hidden_size
<=
1536
?
256
:
(
hidden_size
<=
3072
?
512
:
1024
);
// Request the kernel launcher.
auto
launcher
=
get_fwd_launcher
(
wtype
,
itype
,
rtype
,
otype
,
ctype
,
hidden_size
);
auto
launcher
=
get_fwd_launcher
(
wtype
,
itype
,
rtype
,
otype
,
ctype
,
round_multiple
(
hidden_size
,
multiple
)
);
// Query the kernel-specific launch parameters.
launcher
(
launch_params
,
true
);
...
...
@@ -178,6 +181,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
params
.
z
=
z
.
data_ptr
();
params
.
epsilon
=
epsilon
;
params
.
dropout_scale
=
1.
f
/
(
1.
f
-
dropout_p
);
params
.
inverse_cols
=
1.
f
/
float
(
params
.
cols
);
if
(
dropout_p
>
0.
f
)
{
// number of times random will be generated per thread, to offset philox counter in thc random
...
...
@@ -263,6 +267,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
}
auto
hidden_size
=
gamma
.
numel
();
TORCH_CHECK
(
hidden_size
==
cols
);
TORCH_CHECK
((
hidden_size
%
8
==
0
)
&&
(
hidden_size
<=
6144
));
TORCH_CHECK
(
mu
.
numel
()
==
rows
);
TORCH_CHECK
(
mu
.
sizes
()
==
rsigma
.
sizes
());
...
...
@@ -285,7 +291,9 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
launch_params
.
params
.
dx1
=
has_residual
?
dx1
.
data_ptr
()
:
nullptr
;
launch_params
.
params
.
rowscale
=
rowscale_
.
has_value
()
?
rowscale_
.
value
().
data_ptr
()
:
nullptr
;
auto
launcher
=
get_bwd_launcher
(
wtype
,
itype
,
rtype
,
otype
,
ctype
,
hidden_size
);
auto
round_multiple
=
[](
int
x
,
int
m
)
{
return
(
x
+
m
-
1
)
/
m
*
m
;
};
const
int
multiple
=
hidden_size
<=
1536
?
256
:
(
hidden_size
<=
3072
?
512
:
1024
);
auto
launcher
=
get_bwd_launcher
(
wtype
,
itype
,
rtype
,
otype
,
ctype
,
round_multiple
(
hidden_size
,
multiple
));
launcher
(
launch_params
,
true
,
/*prenorm=*/
false
);
...
...
@@ -308,6 +316,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
params
.
dbeta_part
=
dbeta_part
.
data_ptr
();
params
.
dgamma_part
=
dgamma_part
.
data_ptr
();
params
.
dropout_scale
=
1.
f
/
(
1.
f
-
dropout_p
);
params
.
inverse_cols
=
1.
f
/
float
(
params
.
cols
);
if
(
launch_params
.
barrier_size
>
0
)
{
// TODO Any way to avoid this?
...
...
@@ -385,6 +394,8 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
}
auto
hidden_size
=
gamma
.
numel
();
TORCH_CHECK
(
hidden_size
==
cols
);
TORCH_CHECK
((
hidden_size
%
8
==
0
)
&&
(
hidden_size
<=
6144
));
TORCH_CHECK
(
mu
.
numel
()
==
rows
);
TORCH_CHECK
(
mu
.
sizes
()
==
rsigma
.
sizes
());
...
...
@@ -407,8 +418,9 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
launch_params
.
params
.
dx1
=
has_residual
?
dx1
.
data_ptr
()
:
nullptr
;
launch_params
.
params
.
rowscale
=
rowscale_
.
has_value
()
?
rowscale_
.
value
().
data_ptr
()
:
nullptr
;
// TODO: how to set template param for launcher
auto
launcher
=
get_bwd_launcher
(
wtype
,
itype
,
rtype
,
otype
,
ctype
,
hidden_size
);
auto
round_multiple
=
[](
int
x
,
int
m
)
{
return
(
x
+
m
-
1
)
/
m
*
m
;
};
const
int
multiple
=
hidden_size
<=
1536
?
256
:
(
hidden_size
<=
3072
?
512
:
1024
);
auto
launcher
=
get_bwd_launcher
(
wtype
,
itype
,
rtype
,
otype
,
ctype
,
round_multiple
(
hidden_size
,
multiple
));
launcher
(
launch_params
,
true
,
/*prenorm=*/
true
);
...
...
@@ -432,6 +444,7 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
params
.
dbeta_part
=
dbeta_part
.
data_ptr
();
params
.
dgamma_part
=
dgamma_part
.
data_ptr
();
params
.
dropout_scale
=
1.
f
/
(
1.
f
-
dropout_p
);
params
.
inverse_cols
=
1.
f
/
float
(
params
.
cols
);
if
(
launch_params
.
barrier_size
>
0
)
{
// TODO Any way to avoid this?
...
...
csrc/layer_norm/ln_bwd_1024.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 1024: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Every entry shares the same trailing arguments:
// CTAS_PER_ROW=1, WARPS_M=4, WARPS_N=1, BYTES_PER_LDG=16, BYTES_PER_LDG_FINAL=4.
REGISTER_BWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
csrc/layer_norm/ln_bwd_1280.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 1280: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Every entry shares the same trailing arguments:
// CTAS_PER_ROW=1, WARPS_M=4, WARPS_N=1, BYTES_PER_LDG=16, BYTES_PER_LDG_FINAL=4.
REGISTER_BWD_LAUNCHER(1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
csrc/layer_norm/ln_bwd_1536.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 1536: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Common arguments: CTAS_PER_ROW=1, WARPS_M=1, WARPS_N=4, BYTES_PER_LDG_FINAL=4.
// BYTES_PER_LDG is 16 when ITYPE is fp32 and 8 when ITYPE is fp16/bf16
// (presumably so the elements per load divide the row evenly -- TODO confirm
// against Kernel_traits).
REGISTER_BWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
csrc/layer_norm/ln_bwd_2048.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 2048: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Every entry shares the same trailing arguments:
// CTAS_PER_ROW=1, WARPS_M=1, WARPS_N=4, BYTES_PER_LDG=16, BYTES_PER_LDG_FINAL=4.
REGISTER_BWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
csrc/layer_norm/ln_bwd_256.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 256: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Every entry shares the same trailing arguments:
// CTAS_PER_ROW=1, WARPS_M=4, WARPS_N=1, BYTES_PER_LDG=16, BYTES_PER_LDG_FINAL=4.
REGISTER_BWD_LAUNCHER(256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
csrc/layer_norm/ln_bwd_2560.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 2560: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Common arguments: CTAS_PER_ROW=1, WARPS_M=1, WARPS_N=4, BYTES_PER_LDG_FINAL=4.
// BYTES_PER_LDG is 16 when ITYPE is fp32 and 8 when ITYPE is fp16/bf16
// (presumably so the elements per load divide the row evenly -- TODO confirm
// against Kernel_traits).
REGISTER_BWD_LAUNCHER(2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER(2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
csrc/layer_norm/ln_bwd_3072.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 3072: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Every entry shares the same trailing arguments:
// CTAS_PER_ROW=1, WARPS_M=1, WARPS_N=4, BYTES_PER_LDG=16, BYTES_PER_LDG_FINAL=4.
REGISTER_BWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
csrc/layer_norm/ln_bwd_4096.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 4096: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Every entry shares the same trailing arguments:
// CTAS_PER_ROW=1, WARPS_M=1, WARPS_N=4, BYTES_PER_LDG=16, BYTES_PER_LDG_FINAL=4.
REGISTER_BWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
csrc/layer_norm/ln_bwd_512.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 512: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Every entry shares the same trailing arguments:
// CTAS_PER_ROW=1, WARPS_M=4, WARPS_N=1, BYTES_PER_LDG=16, BYTES_PER_LDG_FINAL=4.
REGISTER_BWD_LAUNCHER(512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
csrc/layer_norm/ln_bwd_5120.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 5120: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Every entry shares the same trailing arguments:
// CTAS_PER_ROW=1, WARPS_M=1, WARPS_N=4, BYTES_PER_LDG=16, BYTES_PER_LDG_FINAL=4.
REGISTER_BWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
csrc/layer_norm/ln_bwd_6144.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 6144: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Every entry shares the same trailing arguments:
// CTAS_PER_ROW=1, WARPS_M=1, WARPS_N=8, BYTES_PER_LDG=16, BYTES_PER_LDG_FINAL=4.
REGISTER_BWD_LAUNCHER(6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
\ No newline at end of file
csrc/layer_norm/ln_bwd_768.cu
0 → 100644
View file @
8c6609ae
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 768: one registration per
// (WTYPE, ITYPE, RTYPE, OTYPE) combination, CTYPE is always fp32.
// Every entry shares the same trailing arguments:
// CTAS_PER_ROW=1, WARPS_M=4, WARPS_N=1, BYTES_PER_LDG=16, BYTES_PER_LDG_FINAL=4.
REGISTER_BWD_LAUNCHER(768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
csrc/layer_norm/ln_bwd_kernels.cuh
View file @
8c6609ae
#pragma once
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "static_switch.h"
namespace
layer_norm
{
template
<
typename
Ktraits
,
bool
Prenorm
,
bool
Is_dropout
,
bool
Has_residual
,
bool
Has_rowscale
>
template
<
typename
Ktraits
,
bool
Prenorm
,
bool
Is_dropout
,
bool
Has_residual
,
bool
Is_even_cols
>
__global__
__launch_bounds__
(
Ktraits
::
THREADS_PER_CTA
)
void
ln_bwd_kernel
(
layer_norm
::
BwdParams
params
)
{
...
...
@@ -59,14 +64,18 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
Sum
<
reduce_t
>
sum
;
constexpr
float
rn
=
1.
f
/
float
(
COLS
);
const
index_t
num_valid_ldgs
=
((
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
)
-
1
-
c
+
Ktraits
::
VEC_COLS_PER_LDG
)
/
Ktraits
::
VEC_COLS_PER_LDG
;
Wvec
gamma
[
LDGS
];
index_t
idx
=
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
if
(
Is_even_cols
||
(
it
<
num_valid_ldgs
))
{
gamma
[
it
].
load_from
(
params
.
gamma
,
idx
);
idx
+=
Ktraits
::
VEC_COLS_PER_LDG
;
}
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
...
...
@@ -74,16 +83,18 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
for
(
int
row
=
r
;
row
<
params
.
rows
;
row
+=
params
.
ctas_per_col
*
ROWS_PER_CTA
)
{
const
compute_t
mu_r
=
static_cast
<
const
compute_t
*>
(
params
.
mu
)[
row
];
const
compute_t
rs_r
=
static_cast
<
const
compute_t
*>
(
params
.
rs
)[
row
];
const
compute_t
rowscale_val
=
Has_rowscale
?
compute_t
(
static_cast
<
const
input_t
*>
(
params
.
rowscale
)[
row
])
:
1.0
f
;
const
compute_t
rowscale_val
=
params
.
rowscale
==
nullptr
?
1.0
f
:
compute_t
(
static_cast
<
const
input_t
*>
(
params
.
rowscale
)[
row
]);
Mvec
dmask
[
LDGS
];
Rvec
dx
[
LDGS
];
compute_t
dy
[
LDGS
*
NUM_ELTS
];
compute_t
y
[
LDGS
*
NUM_ELTS
];
compute_t
mdy_local
=
0.
f
;
compute_t
mdyy_local
=
0.
f
;
index_t
idx
=
row
*
Ktraits
::
VEC_COLS
+
c
;
index_t
idx
=
row
*
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
+
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
if
(
Is_even_cols
||
(
it
<
num_valid_ldgs
))
{
Rvec
x
;
Ovec
dz
;
dz
.
load_from
(
params
.
dz
,
idx
);
...
...
@@ -95,8 +106,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
compute_t
x_tmp
=
x
.
data
.
elt
[
jt
];
compute_t
y_tmp
=
rs_r
*
(
x_tmp
-
mu_r
);
compute_t
dy_tmp
=
compute_t
(
gamma
[
it
].
data
.
elt
[
jt
]);
dy_tmp
*=
compute_t
(
dz
.
data
.
elt
[
jt
]);
compute_t
dy_tmp
=
compute_t
(
gamma
[
it
].
data
.
elt
[
jt
])
*
compute_t
(
dz
.
data
.
elt
[
jt
]);
compute_t
dz_tmp
=
dz
.
data
.
elt
[
jt
];
mdy_local
+=
dy_tmp
;
...
...
@@ -109,14 +119,16 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
dz_sum
[
it
].
data
.
elt
[
jt
]
+=
dz_tmp
;
}
}
}
reduce_t
result
=
reducer
.
allreduce
({
mdy_local
,
mdyy_local
},
sum
);
mdy_local
=
layer_norm
::
Get
<
0
>::
of
<
reduce_t
,
compute_t
>
(
result
)
*
rn
;
mdyy_local
=
layer_norm
::
Get
<
1
>::
of
<
reduce_t
,
compute_t
>
(
result
)
*
rn
;
mdy_local
=
layer_norm
::
Get
<
0
>::
of
<
reduce_t
,
compute_t
>
(
result
)
*
params
.
inverse_cols
;
mdyy_local
=
layer_norm
::
Get
<
1
>::
of
<
reduce_t
,
compute_t
>
(
result
)
*
params
.
inverse_cols
;
idx
=
row
*
Ktraits
::
VEC_COLS
+
c
;
idx
=
row
*
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
+
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
if
(
Is_even_cols
||
(
it
<
num_valid_ldgs
))
{
Ivec
dx0
;
Rvec
dx1
;
#pragma unroll
...
...
@@ -126,7 +138,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
compute_t
dx_tmp
=
rs_r
*
(
dy_tmp
-
(
mdyy_local
*
y_tmp
+
mdy_local
));
compute_t
dx_tmp_res
=
Prenorm
?
dx_tmp
+
compute_t
(
dx
[
it
].
data
.
elt
[
jt
])
:
dx_tmp
;
if
(
Has_residual
)
{
dx1
.
data
.
elt
[
jt
]
=
dx_tmp_res
;
}
compute_t
dx0_tmp_res
=
Has_rowscale
?
dx_tmp_res
*
rowscale_val
:
dx_tmp_res
;
compute_t
dx0_tmp_res
=
dx_tmp_res
*
rowscale_val
;
if
(
Is_dropout
)
{
dx0
.
data
.
elt
[
jt
]
=
dmask
[
it
].
data
.
elt
[
jt
]
?
dx0_tmp_res
*
params
.
dropout_scale
:
0.
f
;
}
else
{
...
...
@@ -137,17 +149,20 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
dx0
.
store_to
(
params
.
dx0
,
idx
);
idx
+=
Ktraits
::
VEC_COLS_PER_LDG
;
}
}
}
// end: grid stride loop
if
(
WARPS_M
==
1
)
{
idx
=
r
*
Ktraits
::
VEC_COLS
+
c
;
idx
=
r
*
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
+
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
if
(
Is_even_cols
||
(
it
<
num_valid_ldgs
))
{
dz_sum
[
it
].
store_to
(
params
.
dbeta_part
,
idx
);
dzy_sum
[
it
].
store_to
(
params
.
dgamma_part
,
idx
);
idx
+=
Ktraits
::
VEC_COLS_PER_LDG
;
}
}
}
else
{
static_assert
(
WARPS_M
==
1
||
Ktraits
::
CTAS_PER_ROW
==
1
,
"Multiple rows per CTA not supported for Multi-CTA."
);
// Finalize reduction of part dgamma and dbeta for this CTA
...
...
@@ -188,21 +203,23 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
}
}
compute_t
*
dgamma_part
=
static_cast
<
compute_t
*>
(
params
.
dgamma_part
)
+
bidm
*
COLS
+
tidx
;
const
index_t
num_valid_writes
=
(
params
.
cols
-
1
-
tidx
+
Ktraits
::
THREADS_PER_CTA
)
/
Ktraits
::
THREADS_PER_CTA
;
compute_t
*
dgamma_part
=
static_cast
<
compute_t
*>
(
params
.
dgamma_part
)
+
bidm
*
params
.
cols
+
tidx
;
compute_t
*
dbeta_part
=
static_cast
<
compute_t
*>
(
params
.
dbeta_part
)
+
bidm
*
params
.
cols
+
tidx
;
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
if
(
Is_even_cols
||
(
jt
<
num_valid_writes
))
{
*
dgamma_part
=
cta_dzy_sum
[
jt
];
dgamma_part
+=
Ktraits
::
THREADS_PER_CTA
;
}
compute_t
*
dbeta_part
=
static_cast
<
compute_t
*>
(
params
.
dbeta_part
)
+
bidm
*
COLS
+
tidx
;
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
*
dbeta_part
=
cta_dz_sum
[
jt
];
dbeta_part
+=
Ktraits
::
THREADS_PER_CTA
;
}
}
}
}
template
<
typename
Kernel_traits
>
template
<
typename
Kernel_traits
,
bool
Is_even_cols
>
__global__
__launch_bounds__
(
Kernel_traits
::
THREADS_PER_CTA
)
void
ln_bwd_finalize_kernel
(
BwdParams
params
)
{
...
...
@@ -236,8 +253,10 @@ void ln_bwd_finalize_kernel(BwdParams params)
Vec
<
compute_t
,
NUM_ELT
>
dbeta_local
,
dgamma_local
;
memset
(
&
dgamma_local
,
0
,
sizeof
(
dgamma_local
));
memset
(
&
dbeta_local
,
0
,
sizeof
(
dbeta_local
));
if
(
Is_even_cols
||
col
<
params
.
cols
)
{
for
(
uint32_t
row
=
warp
;
row
<
params
.
ctas_per_col
;
row
+=
Kernel_traits
::
ROWS_PER_CTA
)
{
index_t
idx
=
row
*
Kernel_traits
::
COLS
+
col
;
// index_t idx = row * Kernel_traits::COLS + col;
index_t
idx
=
row
*
params
.
cols
+
col
;
Vec
<
compute_t
,
NUM_ELT
>
dbeta_part
,
dgamma_part
;
dbeta_part
.
load_from
(
params
.
dbeta_part
,
idx
);
...
...
@@ -248,7 +267,7 @@ void ln_bwd_finalize_kernel(BwdParams params)
dbeta_local
.
data
.
elt
[
it
]
+=
dbeta_part
.
data
.
elt
[
it
];
}
}
}
void
*
smem_gamma
=
smem_
;
void
*
smem_beta
=
&
smem_
[
Kernel_traits
::
SMEM_BYTES_TRANSPOSE
];
...
...
@@ -305,6 +324,7 @@ void ln_bwd_finalize_kernel(BwdParams params)
__syncthreads
();
// Pack and store: 2-wide stores with half the threads.
if
(
Is_even_cols
||
col_out
*
2
<
params
.
cols
)
{
if
(
warp
==
Kernel_traits
::
ROWS_PER_CTA
-
1
&&
lane
<
THREADS_PER_WARP
/
2
)
{
using
src_t
=
typename
TypeToVec2
<
compute_t
>::
Type
;
...
...
@@ -324,5 +344,95 @@ void ln_bwd_finalize_kernel(BwdParams params)
}
}
}
}
}
// namespace layer_norm
using
namespace
layer_norm
;
// Configure or launch the LayerNorm backward pass for one compiled
// combination of element types, HIDDEN_SIZE and tiling parameters.
//
// launch_params    : carries the kernel parameters (`params`), the stream,
//                    the device properties (`props`) and the barrier /
//                    workspace sizes reported back to the caller.
// configure_params : when true, nothing is launched. The function only fills
//                    in `params.ctas_per_col`, `barrier_size` and
//                    `workspace_bytes` and returns.
// prenorm          : selects the Prenorm template variant of ln_bwd_kernel.
//
// When `configure_params` is false the main backward kernel is launched
// (a regular launch for CTAS_PER_ROW == 1, a cooperative launch otherwise),
// followed by ln_bwd_finalize_kernel on the same stream.
template<
    typename weight_t,
    typename input_t,
    typename residual_t,
    typename output_t,
    typename compute_t,
    typename index_t,
    int HIDDEN_SIZE,
    int CTAS_PER_ROW,
    int WARPS_M,
    int WARPS_N,
    int BYTES_PER_LDG_MAIN,
    int BYTES_PER_LDG_FINAL
>
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){

    using Kernel_traits = Kernel_traits<weight_t,
                                        input_t,
                                        residual_t,
                                        output_t,
                                        compute_t,
                                        index_t,
                                        HIDDEN_SIZE,
                                        CTAS_PER_ROW,
                                        WARPS_M,
                                        WARPS_N,
                                        BYTES_PER_LDG_MAIN
                                        >;
    // Runtime flags that are turned into compile-time template arguments below.
    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
    bool has_residual = launch_params.params.dx1 != nullptr;
    // True when the runtime column count equals the compiled HIDDEN_SIZE;
    // the kernels guard their loads/stores with this flag.
    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
    BOOL_SWITCH(prenorm, PrenormConst, [&] {
        BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
            BOOL_SWITCH(has_residual, HasResidualConst, [&] {
                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
                    auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
                    if( configure_params ) {
                        // Query-only mode: derive the grid size from the occupancy of
                        // this exact kernel instantiation, then report the barrier and
                        // workspace sizes needed by the multi-CTA-per-row path.
                        int ctas_per_sm;
                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                        launch_params.barrier_size = 0;
                        launch_params.workspace_bytes = 0;
                        if(Kernel_traits::CTAS_PER_ROW > 1) {
                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
                                                          * Kernel_traits::WARPS_M
                                                          * Kernel_traits::CTAS_PER_ROW
                                                          * sizeof(typename Kernel_traits::reduce_t)
                                                          * 2;
                        }
                        return;
                    }

                    // Opt in to a larger dynamic shared memory limit when the kernel needs it.
                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
                    }
                    auto stream = launch_params.stream;
                    auto ctas_per_col = launch_params.params.ctas_per_col;

                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
                    } else {
                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                        dim3 block(Kernel_traits::THREADS_PER_CTA);
                        void *params_ = (void *)&launch_params.params;
                        // Check the launch status: if the cooperative launch is rejected,
                        // the finalize kernel below would otherwise run on partial sums
                        // that were never written.
                        CHECK_CUDA(cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream));
                    }

                    // Second stage: ln_bwd_finalize_kernel, launched on the same stream
                    // with the same parameter struct.
                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
                                                                              weight_t,
                                                                              input_t,
                                                                              residual_t,
                                                                              output_t,
                                                                              compute_t,
                                                                              index_t,
                                                                              32 * 32,  // THREADS_PER_CTA
                                                                              BYTES_PER_LDG_FINAL>;

                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>;
                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
                });
            });
        });
    });
}
csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu
deleted
100644 → 0
View file @
8a2ece89
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_bwd_kernels.cuh"
#include "static_switch.h"
using
namespace
layer_norm
;
/**
 * Configure or launch the LayerNorm backward pass for one compile-time
 * combination of element types, HIDDEN_SIZE and tiling parameters.
 *
 * The runtime flags (prenorm, dropout, residual, rowscale) are turned into
 * compile-time constants through nested BOOL_SWITCH invocations so that the
 * matching ln_bwd_kernel instantiation is selected.
 *
 * @param launch_params     Launch state. In configure mode its
 *                          params.ctas_per_col, barrier_size and
 *                          workspace_bytes fields are written; in launch mode
 *                          its stream and params are read.
 * @param configure_params  When true, only compute the launch configuration
 *                          and return without launching any kernel.
 * @param prenorm           Forwarded to the kernel as the PrenormConst
 *                          template argument.
 */
template<
    typename weight_t,
    typename input_t,
    typename residual_t,
    typename output_t,
    typename compute_t,
    typename index_t,
    int HIDDEN_SIZE,
    int CTAS_PER_ROW,
    int WARPS_M,
    int WARPS_N,
    int BYTES_PER_LDG_MAIN,
    int BYTES_PER_LDG_FINAL
>
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){

    using Kernel_traits = Kernel_traits<weight_t,
                                        input_t,
                                        residual_t,
                                        output_t,
                                        compute_t,
                                        index_t,
                                        HIDDEN_SIZE,
                                        CTAS_PER_ROW,
                                        WARPS_M,
                                        WARPS_N,
                                        BYTES_PER_LDG_MAIN
                                        >;

    // Runtime switches derived from the parameter struct: dropout is active
    // when the keep probability is below 1, the residual / rowscale paths are
    // active when the corresponding pointers are non-null.
    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
    bool has_residual = launch_params.params.dx1 != nullptr;
    bool has_rowscale = launch_params.params.rowscale != nullptr;

    BOOL_SWITCH(prenorm, PrenormConst, [&] {
        BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
            BOOL_SWITCH(has_residual, HasResidualConst, [&] {
                BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
                    auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasRowscaleConst>;

                    if( configure_params ) {
                        // Configuration pass: size the grid from the occupancy
                        // of this exact kernel instantiation, then report the
                        // barrier / workspace sizes. Nothing is launched.
                        int ctas_per_sm;
                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                        launch_params.barrier_size = 0;
                        launch_params.workspace_bytes = 0;
                        if( Kernel_traits::CTAS_PER_ROW > 1 ) {
                            // Barrier and workspace are only requested when a
                            // row is split across several CTAs.
                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
                                                          * Kernel_traits::WARPS_M
                                                          * Kernel_traits::CTAS_PER_ROW
                                                          * sizeof(typename Kernel_traits::reduce_t)
                                                          * 2;
                        }
                        // Leaves the innermost lambda only; no kernel is launched.
                        return;
                    }

                    // Raise the kernel's dynamic shared memory limit when the
                    // request reaches 48 KiB.
                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
                    }
                    auto stream = launch_params.stream;
                    auto ctas_per_col = launch_params.params.ctas_per_col;

                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
                    } else {
                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                        dim3 block(Kernel_traits::THREADS_PER_CTA);
                        void *params_ = (void *)&launch_params.params;
                        // Check the launch result like every other runtime call
                        // in this function instead of silently dropping it.
                        CHECK_CUDA(cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream));
                    }

                    // Second kernel on the same stream: the finalize step.
                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
                                                                                weight_t,
                                                                                input_t,
                                                                                residual_t,
                                                                                output_t,
                                                                                compute_t,
                                                                                index_t,
                                                                                32 * 32,  // THREADS_PER_CTA
                                                                                BYTES_PER_LDG_FINAL>;

                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;
                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
                });
            });
        });
    });
}
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Backward launchers for HIDDEN_SIZE = 768, one per supported dtype combination.
// Every entry uses fp32 compute and the tiling 1, 4, 1, 16, 4.
REGISTER_BWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
// Backward launchers for HIDDEN_SIZE = 1024, one per supported dtype combination.
// Every entry uses fp32 compute and the tiling 1, 4, 1, 16, 4.
REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
// Backward launchers for HIDDEN_SIZE = 1280, one per supported dtype combination.
// Every entry uses fp32 compute and the tiling 1, 4, 1, 16, 4.
REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
// Backward launchers for HIDDEN_SIZE = 1536, one per supported dtype combination.
// Tiling is 1, 1, 4 throughout; the main load width is 16 bytes for the two
// fp32-input entries and 8 bytes for the rest, with a 4-byte finalize load.
REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
// Backward launchers for HIDDEN_SIZE = 1600, one per supported dtype combination.
// Every entry uses fp32 compute and the tiling 1, 2, 1, 4, 4.
REGISTER_BWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 2, 1,  4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 2, 1,  4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 2, 1,  4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 2, 1,  4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 2, 1,  4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 2, 1,  4, 4);
REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 2, 1,  4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 2, 1,  4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 2, 1,  4, 4);
REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 2, 1,  4, 4);
// Backward launchers for HIDDEN_SIZE = 2048, one per supported dtype combination.
// Every entry uses fp32 compute and the tiling 1, 1, 4, 16, 4.
REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// Backward launchers for HIDDEN_SIZE = 2560, one per supported dtype combination.
// Tiling is 1, 1, 4 throughout; the main load width is 16 bytes for the two
// fp32-input entries and 8 bytes for the rest, with a 4-byte finalize load.
REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
// Backward launchers for HIDDEN_SIZE = 3072, one per supported dtype combination.
// Every entry uses fp32 compute and the tiling 1, 1, 4, 16, 4.
REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// Backward launchers for HIDDEN_SIZE = 4096, one per supported dtype combination.
// Every entry uses fp32 compute and the tiling 1, 1, 4, 16, 4.
REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// Backward launchers for HIDDEN_SIZE = 5120, one per supported dtype combination.
// Every entry uses fp32 compute and the tiling 1, 1, 4, 16, 4.
REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// TD [2022-04-22] Disable most of these to speed up compile time
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);
// REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);
// REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);
// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
// REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
csrc/layer_norm/ln_fwd_1024.cu
0 → 100644
View file @
8c6609ae
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
// Forward launchers for HIDDEN_SIZE = 1024, one per supported dtype combination.
// Every entry uses fp32 compute and the tiling 1, 4, 1, 16.
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_fwd_1280.cu
0 → 100644
View file @
8c6609ae
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
// Forward launchers for HIDDEN_SIZE = 1280, one per supported dtype combination.
// Every entry uses fp32 compute and the tiling 1, 4, 1, 16.
REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_fwd_1536.cu
0 → 100644
View file @
8c6609ae
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
// Forward launchers for HIDDEN_SIZE = 1536, one per supported dtype combination.
// Every entry uses fp32 compute and the tiling 1, 4, 1, 16.
REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment