Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
393882bc
Commit
393882bc
authored
Mar 29, 2023
by
Tri Dao
Browse files
[LayerNorm] Implement LN with parallel residual, support dim 8k
parent
009a3e71
Changes
46
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
717 additions
and
39 deletions
+717
-39
csrc/layer_norm/README.md
csrc/layer_norm/README.md
+7
-4
csrc/layer_norm/ln.h
csrc/layer_norm/ln.h
+43
-3
csrc/layer_norm/ln_api.cpp
csrc/layer_norm/ln_api.cpp
+408
-32
csrc/layer_norm/ln_bwd_7168.cu
csrc/layer_norm/ln_bwd_7168.cu
+15
-0
csrc/layer_norm/ln_bwd_8192.cu
csrc/layer_norm/ln_bwd_8192.cu
+15
-0
csrc/layer_norm/ln_fwd_7168.cu
csrc/layer_norm/ln_fwd_7168.cu
+15
-0
csrc/layer_norm/ln_fwd_8192.cu
csrc/layer_norm/ln_fwd_8192.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_1024.cu
csrc/layer_norm/ln_parallel_bwd_1024.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_1280.cu
csrc/layer_norm/ln_parallel_bwd_1280.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_1536.cu
csrc/layer_norm/ln_parallel_bwd_1536.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_2048.cu
csrc/layer_norm/ln_parallel_bwd_2048.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_256.cu
csrc/layer_norm/ln_parallel_bwd_256.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_2560.cu
csrc/layer_norm/ln_parallel_bwd_2560.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_3072.cu
csrc/layer_norm/ln_parallel_bwd_3072.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_4096.cu
csrc/layer_norm/ln_parallel_bwd_4096.cu
+17
-0
csrc/layer_norm/ln_parallel_bwd_512.cu
csrc/layer_norm/ln_parallel_bwd_512.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_5120.cu
csrc/layer_norm/ln_parallel_bwd_5120.cu
+17
-0
csrc/layer_norm/ln_parallel_bwd_6144.cu
csrc/layer_norm/ln_parallel_bwd_6144.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_7168.cu
csrc/layer_norm/ln_parallel_bwd_7168.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_768.cu
csrc/layer_norm/ln_parallel_bwd_768.cu
+15
-0
No files found.
csrc/layer_norm/README.md
View file @
393882bc
This CUDA extension implements fused dropout + residual + LayerNorm, building on
This CUDA extension implements fused dropout + residual + LayerNorm, building on
Apex's
[
FastLayerNorm
](
https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm
)
.
Apex's
[
FastLayerNorm
](
https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm
)
.
We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
Major changes:
We also make it work for more hidden dimensions (all dimensions divisible by 8, up to 6144).
-
Add dropout and residual.
We also implement RMSNorm as an option.
-
Make it work for both pre-norm and post-norm architectures.
-
Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
-
Implement RMSNorm as an option.
-
Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).
If you want to use it for dimensions larger than
6
k, please file an issue.
If you want to use it for dimensions larger than
8
k, please file an issue.
This extension has only been tested on A100s.
This extension has only been tested on A100s.
...
...
csrc/layer_norm/ln.h
View file @
393882bc
...
@@ -14,7 +14,7 @@ namespace layer_norm {
...
@@ -14,7 +14,7 @@ namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Params
>
template
<
typename
Params
>
struct
LaunchParams
{
struct
LaunchParams
{
size_t
elts_per_thread
;
size_t
elts_per_thread
;
...
@@ -40,6 +40,7 @@ struct ParamsBase {
...
@@ -40,6 +40,7 @@ struct ParamsBase {
,
mu
(
nullptr
)
,
mu
(
nullptr
)
,
rs
(
nullptr
)
,
rs
(
nullptr
)
,
gamma
(
nullptr
)
,
gamma
(
nullptr
)
,
gamma1
(
nullptr
)
,
rowscale
(
nullptr
)
,
rowscale
(
nullptr
)
,
colscale
(
nullptr
)
,
colscale
(
nullptr
)
,
dropout_keep_p
(
1.
f
)
,
dropout_keep_p
(
1.
f
)
...
@@ -59,12 +60,15 @@ struct ParamsBase {
...
@@ -59,12 +60,15 @@ struct ParamsBase {
// Common data pointers.
// Common data pointers.
void
*
x0
;
void
*
x0
;
void
*
x1
;
void
*
residual
;
void
*
residual
;
void
*
x
;
void
*
x
;
void
*
dmask
;
void
*
dmask
;
void
*
dmask1
;
void
*
mu
;
void
*
mu
;
void
*
rs
;
void
*
rs
;
void
*
gamma
;
void
*
gamma
;
void
*
gamma1
;
void
*
rowscale
;
void
*
rowscale
;
void
*
colscale
;
void
*
colscale
;
void
*
x0_subset
;
void
*
x0_subset
;
...
@@ -92,14 +96,18 @@ struct FwdParams : public ParamsBase {
...
@@ -92,14 +96,18 @@ struct FwdParams : public ParamsBase {
FwdParams
()
FwdParams
()
:
ParamsBase
()
:
ParamsBase
()
,
z
(
nullptr
)
,
z
(
nullptr
)
,
z1
(
nullptr
)
,
beta
(
nullptr
)
,
beta
(
nullptr
)
,
beta1
(
nullptr
)
,
epsilon
(
0.
f
)
,
epsilon
(
0.
f
)
{
{
}
}
// Output of LN FWD.
// Output of LN FWD.
void
*
z
;
void
*
z
;
void
*
z1
;
void
*
beta
;
void
*
beta
;
void
*
beta1
;
float
epsilon
;
float
epsilon
;
// Random state.
// Random state.
...
@@ -112,34 +120,46 @@ struct BwdParams : public ParamsBase {
...
@@ -112,34 +120,46 @@ struct BwdParams : public ParamsBase {
BwdParams
()
BwdParams
()
:
ParamsBase
()
:
ParamsBase
()
,
dz
(
nullptr
)
,
dz
(
nullptr
)
,
dz1
(
nullptr
)
,
dx
(
nullptr
)
,
dx
(
nullptr
)
,
dbeta_part
(
nullptr
)
,
dbeta_part
(
nullptr
)
,
dgamma_part
(
nullptr
)
,
dgamma_part
(
nullptr
)
,
dbeta1_part
(
nullptr
)
,
dgamma1_part
(
nullptr
)
,
dcolscale_part
(
nullptr
)
,
dcolscale_part
(
nullptr
)
,
dx0
(
nullptr
)
,
dx0
(
nullptr
)
,
dx1
(
nullptr
)
,
dresidual
(
nullptr
)
,
dresidual
(
nullptr
)
,
dbeta
(
nullptr
)
,
dbeta
(
nullptr
)
,
dgamma
(
nullptr
)
,
dgamma
(
nullptr
)
,
dbeta1
(
nullptr
)
,
dgamma1
(
nullptr
)
,
dcolscale
(
nullptr
)
,
dcolscale
(
nullptr
)
{
{
}
}
// Input: gradient wrt. LN FWD output.
// Input: gradient wrt. LN FWD output.
void
*
dz
;
void
*
dz
;
void
*
dz1
;
// Input: gradient wrt residual.
// Input: gradient wrt residual.
void
*
dx
;
void
*
dx
;
// Workspace for Wgrad pre-reduction.
// Workspace for Wgrad pre-reduction.
void
*
dbeta_part
;
void
*
dbeta_part
;
void
*
dgamma_part
;
void
*
dgamma_part
;
void
*
dbeta1_part
;
void
*
dgamma1_part
;
void
*
dcolscale_part
;
void
*
dcolscale_part
;
// Output: Dgrad.
// Output: Dgrad.
void
*
dx0
;
void
*
dx0
;
void
*
dx1
;
void
*
dresidual
;
void
*
dresidual
;
// Output: Wgrad.
// Output: Wgrad.
void
*
dbeta
;
void
*
dbeta
;
void
*
dgamma
;
void
*
dgamma
;
void
*
dbeta1
;
void
*
dgamma1
;
void
*
dcolscale
;
void
*
dcolscale
;
};
};
...
@@ -152,8 +172,8 @@ using FunctionKey = uint64_t;
...
@@ -152,8 +172,8 @@ using FunctionKey = uint64_t;
using
FwdRegistry
=
std
::
unordered_map
<
FunctionKey
,
FwdFunction
>
;
using
FwdRegistry
=
std
::
unordered_map
<
FunctionKey
,
FwdFunction
>
;
using
BwdRegistry
=
std
::
unordered_map
<
FunctionKey
,
BwdFunction
>
;
using
BwdRegistry
=
std
::
unordered_map
<
FunctionKey
,
BwdFunction
>
;
extern
FwdRegistry
FWD_FUNCS
;
extern
FwdRegistry
FWD_FUNCS
,
PARALLEL_
FWD_FUNCS
;
extern
BwdRegistry
BWD_FUNCS
;
extern
BwdRegistry
BWD_FUNCS
,
PARALLEL_
BWD_FUNCS
;
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -238,4 +258,24 @@ struct BwdRegistrar{
...
@@ -238,4 +258,24 @@ struct BwdRegistrar{
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
W
,
typename
I
,
typename
R
,
typename
O
,
typename
C
,
uint64_t
HIDDEN_SIZE
>
struct
FwdParallelRegistrar
{
FwdParallelRegistrar
(
FwdFunction
f
){
uint64_t
key
=
Types2Key
<
W
,
I
,
R
,
O
,
C
>::
get
(
HIDDEN_SIZE
);
PARALLEL_FWD_FUNCS
.
insert
({
key
,
f
});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
W
,
typename
I
,
typename
R
,
typename
O
,
typename
C
,
uint64_t
HIDDEN_SIZE
>
struct
BwdParallelRegistrar
{
BwdParallelRegistrar
(
BwdFunction
f
){
uint64_t
key
=
Types2Key
<
W
,
I
,
R
,
O
,
C
>::
get
(
HIDDEN_SIZE
);
PARALLEL_BWD_FUNCS
.
insert
({
key
,
f
});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace layer_norm
}
// namespace layer_norm
csrc/layer_norm/ln_api.cpp
View file @
393882bc
This diff is collapsed.
Click to expand it.
csrc/layer_norm/ln_bwd_7168.cu
0 → 100644
View file @
393882bc
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER
(
7168
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
7168
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
7168
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
7168
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
7168
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
7168
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
7168
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
7168
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
7168
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_BWD_LAUNCHER
(
7168
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
\ No newline at end of file
csrc/layer_norm/ln_bwd_8192.cu
0 → 100644
View file @
393882bc
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER
(
8192
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_BWD_LAUNCHER
(
8192
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
\ No newline at end of file
csrc/layer_norm/ln_fwd_7168.cu
0 → 100644
View file @
393882bc
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER
(
7168
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_FWD_LAUNCHER
(
7168
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_FWD_LAUNCHER
(
7168
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
7168
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
7168
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
7168
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
7168
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
7168
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
7168
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_FWD_LAUNCHER
(
7168
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
csrc/layer_norm/ln_fwd_8192.cu
0 → 100644
View file @
393882bc
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER
(
8192
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_FWD_LAUNCHER
(
8192
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_FWD_LAUNCHER
(
8192
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_FWD_LAUNCHER
(
8192
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_FWD_LAUNCHER
(
8192
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_FWD_LAUNCHER
(
8192
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_FWD_LAUNCHER
(
8192
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_FWD_LAUNCHER
(
8192
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_FWD_LAUNCHER
(
8192
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_FWD_LAUNCHER
(
8192
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
csrc/layer_norm/ln_parallel_bwd_1024.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
1024
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1024
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1024
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1024
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1024
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1024
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1024
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1024
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1024
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1024
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
csrc/layer_norm/ln_parallel_bwd_1280.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
1280
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1280
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1280
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1280
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1280
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1280
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1280
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1280
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1280
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1280
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
csrc/layer_norm/ln_parallel_bwd_1536.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
1536
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1536
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1536
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1536
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1536
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1536
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1536
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1536
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1536
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
1536
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
csrc/layer_norm/ln_parallel_bwd_2048.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
2048
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2048
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2048
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2048
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2048
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2048
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2048
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2048
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2048
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2048
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_256.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
256
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
256
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
256
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
256
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
256
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
256
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
256
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
256
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
256
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
256
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
csrc/layer_norm/ln_parallel_bwd_2560.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
2560
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2560
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2560
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2560
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2560
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2560
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2560
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2560
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2560
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
2560
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
8
,
4
);
csrc/layer_norm/ln_parallel_bwd_3072.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
3072
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
3072
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
3072
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
3072
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
3072
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
3072
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
3072
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
3072
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
3072
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
3072
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
,
4
);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_4096.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Use 8 warps otherwise there's a lot of register spilling
REGISTER_PARALLEL_BWD_LAUNCHER
(
4096
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
4096
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
4096
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
4096
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
4096
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
4096
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
4096
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
4096
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
4096
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
4096
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_512.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
512
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
512
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
512
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
512
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
512
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
512
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
512
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
512
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
512
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
512
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
csrc/layer_norm/ln_parallel_bwd_5120.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Use 8 warps otherwise there's a lot of register spilling
REGISTER_PARALLEL_BWD_LAUNCHER
(
5120
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
5120
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
5120
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
5120
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
5120
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
5120
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
5120
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
5120
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
5120
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
5120
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_6144.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
6144
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
6144
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
6144
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
6144
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
6144
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
6144
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
6144
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
6144
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
6144
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
6144
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_7168.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
7168
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
7168
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
7168
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
7168
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
7168
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
7168
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
7168
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
7168
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
7168
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
7168
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
8
,
4
);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_768.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
768
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
768
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
768
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
768
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
768
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
768
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
768
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
768
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
768
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
768
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
,
4
);
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment