Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
393882bc
Commit
393882bc
authored
Mar 29, 2023
by
Tri Dao
Browse files
[LayerNorm] Implement LN with parallel residual, support dim 8k
parent
009a3e71
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
717 additions
and
39 deletions
+717
-39
csrc/layer_norm/README.md
csrc/layer_norm/README.md
+7
-4
csrc/layer_norm/ln.h
csrc/layer_norm/ln.h
+43
-3
csrc/layer_norm/ln_api.cpp
csrc/layer_norm/ln_api.cpp
+408
-32
csrc/layer_norm/ln_bwd_7168.cu
csrc/layer_norm/ln_bwd_7168.cu
+15
-0
csrc/layer_norm/ln_bwd_8192.cu
csrc/layer_norm/ln_bwd_8192.cu
+15
-0
csrc/layer_norm/ln_fwd_7168.cu
csrc/layer_norm/ln_fwd_7168.cu
+15
-0
csrc/layer_norm/ln_fwd_8192.cu
csrc/layer_norm/ln_fwd_8192.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_1024.cu
csrc/layer_norm/ln_parallel_bwd_1024.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_1280.cu
csrc/layer_norm/ln_parallel_bwd_1280.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_1536.cu
csrc/layer_norm/ln_parallel_bwd_1536.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_2048.cu
csrc/layer_norm/ln_parallel_bwd_2048.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_256.cu
csrc/layer_norm/ln_parallel_bwd_256.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_2560.cu
csrc/layer_norm/ln_parallel_bwd_2560.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_3072.cu
csrc/layer_norm/ln_parallel_bwd_3072.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_4096.cu
csrc/layer_norm/ln_parallel_bwd_4096.cu
+17
-0
csrc/layer_norm/ln_parallel_bwd_512.cu
csrc/layer_norm/ln_parallel_bwd_512.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_5120.cu
csrc/layer_norm/ln_parallel_bwd_5120.cu
+17
-0
csrc/layer_norm/ln_parallel_bwd_6144.cu
csrc/layer_norm/ln_parallel_bwd_6144.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_7168.cu
csrc/layer_norm/ln_parallel_bwd_7168.cu
+15
-0
csrc/layer_norm/ln_parallel_bwd_768.cu
csrc/layer_norm/ln_parallel_bwd_768.cu
+15
-0
No files found.
csrc/layer_norm/README.md
View file @
393882bc
This CUDA extension implements fused dropout + residual + LayerNorm, building on
This CUDA extension implements fused dropout + residual + LayerNorm, building on
Apex's
[
FastLayerNorm
](
https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm
)
.
Apex's
[
FastLayerNorm
](
https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm
)
.
We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
Major changes:
We also make it work for more hidden dimensions (all dimensions divisible by 8, up to 6144).
-
Add dropout and residual.
We also implement RMSNorm as an option.
-
Make it work for both pre-norm and post-norm architecture.
-
Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
-
Implement RMSNorm as an option.
-
Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).
If you want to use it for dimensions larger than
6
k, please file an issue.
If you want to use it for dimensions larger than
8
k, please file an issue.
This extension has only been tested on A100s.
This extension has only been tested on A100s.
...
...
csrc/layer_norm/ln.h
View file @
393882bc
...
@@ -14,7 +14,7 @@ namespace layer_norm {
...
@@ -14,7 +14,7 @@ namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Params
>
template
<
typename
Params
>
struct
LaunchParams
{
struct
LaunchParams
{
size_t
elts_per_thread
;
size_t
elts_per_thread
;
...
@@ -40,6 +40,7 @@ struct ParamsBase {
...
@@ -40,6 +40,7 @@ struct ParamsBase {
,
mu
(
nullptr
)
,
mu
(
nullptr
)
,
rs
(
nullptr
)
,
rs
(
nullptr
)
,
gamma
(
nullptr
)
,
gamma
(
nullptr
)
,
gamma1
(
nullptr
)
,
rowscale
(
nullptr
)
,
rowscale
(
nullptr
)
,
colscale
(
nullptr
)
,
colscale
(
nullptr
)
,
dropout_keep_p
(
1.
f
)
,
dropout_keep_p
(
1.
f
)
...
@@ -59,12 +60,15 @@ struct ParamsBase {
...
@@ -59,12 +60,15 @@ struct ParamsBase {
// Common data pointers.
// Common data pointers.
void
*
x0
;
void
*
x0
;
void
*
x1
;
void
*
residual
;
void
*
residual
;
void
*
x
;
void
*
x
;
void
*
dmask
;
void
*
dmask
;
void
*
dmask1
;
void
*
mu
;
void
*
mu
;
void
*
rs
;
void
*
rs
;
void
*
gamma
;
void
*
gamma
;
void
*
gamma1
;
void
*
rowscale
;
void
*
rowscale
;
void
*
colscale
;
void
*
colscale
;
void
*
x0_subset
;
void
*
x0_subset
;
...
@@ -92,14 +96,18 @@ struct FwdParams : public ParamsBase {
...
@@ -92,14 +96,18 @@ struct FwdParams : public ParamsBase {
FwdParams
()
FwdParams
()
:
ParamsBase
()
:
ParamsBase
()
,
z
(
nullptr
)
,
z
(
nullptr
)
,
z1
(
nullptr
)
,
beta
(
nullptr
)
,
beta
(
nullptr
)
,
beta1
(
nullptr
)
,
epsilon
(
0.
f
)
,
epsilon
(
0.
f
)
{
{
}
}
// Output of LN FWD.
// Output of LN FWD.
void
*
z
;
void
*
z
;
void
*
z1
;
void
*
beta
;
void
*
beta
;
void
*
beta1
;
float
epsilon
;
float
epsilon
;
// Random state.
// Random state.
...
@@ -112,34 +120,46 @@ struct BwdParams : public ParamsBase {
...
@@ -112,34 +120,46 @@ struct BwdParams : public ParamsBase {
BwdParams
()
BwdParams
()
:
ParamsBase
()
:
ParamsBase
()
,
dz
(
nullptr
)
,
dz
(
nullptr
)
,
dz1
(
nullptr
)
,
dx
(
nullptr
)
,
dx
(
nullptr
)
,
dbeta_part
(
nullptr
)
,
dbeta_part
(
nullptr
)
,
dgamma_part
(
nullptr
)
,
dgamma_part
(
nullptr
)
,
dbeta1_part
(
nullptr
)
,
dgamma1_part
(
nullptr
)
,
dcolscale_part
(
nullptr
)
,
dcolscale_part
(
nullptr
)
,
dx0
(
nullptr
)
,
dx0
(
nullptr
)
,
dx1
(
nullptr
)
,
dresidual
(
nullptr
)
,
dresidual
(
nullptr
)
,
dbeta
(
nullptr
)
,
dbeta
(
nullptr
)
,
dgamma
(
nullptr
)
,
dgamma
(
nullptr
)
,
dbeta1
(
nullptr
)
,
dgamma1
(
nullptr
)
,
dcolscale
(
nullptr
)
,
dcolscale
(
nullptr
)
{
{
}
}
// Input: gradient wrt. LN FWD output.
// Input: gradient wrt. LN FWD output.
void
*
dz
;
void
*
dz
;
void
*
dz1
;
// Input: gradient wrt residual.
// Input: gradient wrt residual.
void
*
dx
;
void
*
dx
;
// Workspace for Wgrad pre-reduction.
// Workspace for Wgrad pre-reduction.
void
*
dbeta_part
;
void
*
dbeta_part
;
void
*
dgamma_part
;
void
*
dgamma_part
;
void
*
dbeta1_part
;
void
*
dgamma1_part
;
void
*
dcolscale_part
;
void
*
dcolscale_part
;
// Output: Dgrad.
// Output: Dgrad.
void
*
dx0
;
void
*
dx0
;
void
*
dx1
;
void
*
dresidual
;
void
*
dresidual
;
// Output: Wgrad.
// Output: Wgrad.
void
*
dbeta
;
void
*
dbeta
;
void
*
dgamma
;
void
*
dgamma
;
void
*
dbeta1
;
void
*
dgamma1
;
void
*
dcolscale
;
void
*
dcolscale
;
};
};
...
@@ -152,8 +172,8 @@ using FunctionKey = uint64_t;
...
@@ -152,8 +172,8 @@ using FunctionKey = uint64_t;
using
FwdRegistry
=
std
::
unordered_map
<
FunctionKey
,
FwdFunction
>
;
using
FwdRegistry
=
std
::
unordered_map
<
FunctionKey
,
FwdFunction
>
;
using
BwdRegistry
=
std
::
unordered_map
<
FunctionKey
,
BwdFunction
>
;
using
BwdRegistry
=
std
::
unordered_map
<
FunctionKey
,
BwdFunction
>
;
extern
FwdRegistry
FWD_FUNCS
;
extern
FwdRegistry
FWD_FUNCS
,
PARALLEL_
FWD_FUNCS
;
extern
BwdRegistry
BWD_FUNCS
;
extern
BwdRegistry
BWD_FUNCS
,
PARALLEL_
BWD_FUNCS
;
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -238,4 +258,24 @@ struct BwdRegistrar{
...
@@ -238,4 +258,24 @@ struct BwdRegistrar{
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
W
,
typename
I
,
typename
R
,
typename
O
,
typename
C
,
uint64_t
HIDDEN_SIZE
>
struct
FwdParallelRegistrar
{
FwdParallelRegistrar
(
FwdFunction
f
){
uint64_t
key
=
Types2Key
<
W
,
I
,
R
,
O
,
C
>::
get
(
HIDDEN_SIZE
);
PARALLEL_FWD_FUNCS
.
insert
({
key
,
f
});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
W
,
typename
I
,
typename
R
,
typename
O
,
typename
C
,
uint64_t
HIDDEN_SIZE
>
struct
BwdParallelRegistrar
{
BwdParallelRegistrar
(
BwdFunction
f
){
uint64_t
key
=
Types2Key
<
W
,
I
,
R
,
O
,
C
>::
get
(
HIDDEN_SIZE
);
PARALLEL_BWD_FUNCS
.
insert
({
key
,
f
});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace layer_norm
}
// namespace layer_norm
csrc/layer_norm/ln_api.cpp
View file @
393882bc
...
@@ -28,8 +28,8 @@ namespace layer_norm {
...
@@ -28,8 +28,8 @@ namespace layer_norm {
// Create registries and provide runtime versions of config hash functions.
// Create registries and provide runtime versions of config hash functions.
FwdRegistry
FWD_FUNCS
;
FwdRegistry
FWD_FUNCS
,
PARALLEL_
FWD_FUNCS
;
BwdRegistry
BWD_FUNCS
;
BwdRegistry
BWD_FUNCS
,
PARALLEL_
BWD_FUNCS
;
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -80,6 +80,28 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
...
@@ -80,6 +80,28 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm
::
FwdFunction
&
get_parallel_fwd_launcher
(
torch
::
Dtype
wtype
,
torch
::
Dtype
itype
,
torch
::
Dtype
rtype
,
torch
::
Dtype
otype
,
torch
::
Dtype
ctype
,
uint32_t
hidden_size
)
{
auto
iter
=
layer_norm
::
PARALLEL_FWD_FUNCS
.
find
(
layer_norm
::
get_key
(
wtype
,
itype
,
rtype
,
otype
,
ctype
,
hidden_size
));
if
(
iter
!=
layer_norm
::
PARALLEL_FWD_FUNCS
.
end
()
)
{
return
iter
->
second
;
}
else
{
TORCH_CHECK
(
false
,
"FWD: Unsupported hidden_size or types: "
,
hidden_size
,
wtype
,
itype
,
rtype
,
otype
,
ctype
);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm
::
BwdFunction
&
get_parallel_bwd_launcher
(
torch
::
Dtype
wtype
,
torch
::
Dtype
itype
,
torch
::
Dtype
rtype
,
torch
::
Dtype
otype
,
torch
::
Dtype
ctype
,
uint32_t
hidden_size
)
{
auto
iter
=
layer_norm
::
PARALLEL_BWD_FUNCS
.
find
(
layer_norm
::
get_key
(
wtype
,
itype
,
rtype
,
otype
,
ctype
,
hidden_size
));
if
(
iter
!=
layer_norm
::
PARALLEL_BWD_FUNCS
.
end
()
)
{
return
iter
->
second
;
}
else
{
TORCH_CHECK
(
false
,
"BWD: Unsupported hidden_size or types: "
,
hidden_size
,
wtype
,
itype
,
rtype
,
otype
,
ctype
);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std
::
vector
<
at
::
Tensor
>
dropout_add_ln_fwd
(
const
at
::
Tensor
&
x0
,
// Input: BxSxhidden_size
std
::
vector
<
at
::
Tensor
>
dropout_add_ln_fwd
(
const
at
::
Tensor
&
x0
,
// Input: BxSxhidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
residual_
,
// Residual: BxSxhidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
residual_
,
// Residual: BxSxhidden_size
const
at
::
Tensor
&
gamma
,
// hidden_size
const
at
::
Tensor
&
gamma
,
// hidden_size
...
@@ -105,8 +127,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
...
@@ -105,8 +127,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
auto
ctype
=
torch
::
kFloat32
;
auto
ctype
=
torch
::
kFloat32
;
auto
mtype
=
torch
::
kUInt8
;
auto
mtype
=
torch
::
kUInt8
;
TORCH_CHECK
(
x0
.
is_cuda
())
TORCH_CHECK
(
x0
.
is_cuda
())
;
TORCH_CHECK
(
gamma
.
is_cuda
())
TORCH_CHECK
(
gamma
.
is_cuda
())
;
TORCH_CHECK
(
x0
.
is_contiguous
());
TORCH_CHECK
(
x0
.
is_contiguous
());
// c10::IntArrayRef does not own the storage, so we need to construct a vector.
// c10::IntArrayRef does not own the storage, so we need to construct a vector.
...
@@ -120,25 +142,26 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
...
@@ -120,25 +142,26 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
const
int
rows
=
sizes
[
0
];
const
int
rows
=
sizes
[
0
];
const
int
cols
=
sizes
[
1
];
const
int
cols
=
sizes
[
1
];
auto
hidden_size
=
gamma
.
numel
();
auto
hidden_size
=
gamma
.
numel
();
TORCH_CHECK
(
hidden_size
==
cols
);
if
(
beta_
.
has_value
())
{
if
(
beta_
.
has_value
())
{
auto
beta
=
beta_
.
value
();
auto
beta
=
beta_
.
value
();
TORCH_CHECK
(
beta
.
dtype
()
==
wtype
);
TORCH_CHECK
(
beta
.
dtype
()
==
wtype
);
TORCH_CHECK
(
beta
.
is_cuda
())
TORCH_CHECK
(
beta
.
is_cuda
())
;
TORCH_CHECK
(
beta
.
is_contiguous
());
TORCH_CHECK
(
beta
.
is_contiguous
());
TORCH_CHECK
(
gamm
a
.
sizes
()
==
bet
a
.
sizes
());
TORCH_CHECK
(
bet
a
.
sizes
()
==
gamm
a
.
sizes
());
}
}
if
(
residual_
.
has_value
())
{
if
(
residual_
.
has_value
())
{
auto
residual
=
residual_
.
value
();
auto
residual
=
residual_
.
value
();
TORCH_CHECK
(
residual
.
is_cuda
())
TORCH_CHECK
(
residual
.
is_cuda
())
;
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
residual
.
sizes
()
==
sizes
);
TORCH_CHECK
(
residual
.
sizes
()
==
sizes
);
}
}
if
(
rowscale_
.
has_value
())
{
if
(
rowscale_
.
has_value
())
{
auto
rowscale
=
rowscale_
.
value
();
auto
rowscale
=
rowscale_
.
value
();
TORCH_CHECK
(
rowscale
.
is_cuda
())
TORCH_CHECK
(
rowscale
.
is_cuda
())
;
TORCH_CHECK
(
rowscale
.
is_contiguous
());
TORCH_CHECK
(
rowscale
.
is_contiguous
());
TORCH_CHECK
(
rowscale
.
sizes
()
==
c10
::
IntArrayRef
{
rows
});
TORCH_CHECK
(
rowscale
.
sizes
()
==
c10
::
IntArrayRef
{
rows
});
TORCH_CHECK
(
rowscale
.
dtype
()
==
itype
);
TORCH_CHECK
(
rowscale
.
dtype
()
==
itype
);
...
@@ -146,7 +169,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
...
@@ -146,7 +169,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
if
(
colscale_
.
has_value
())
{
if
(
colscale_
.
has_value
())
{
auto
colscale
=
colscale_
.
value
();
auto
colscale
=
colscale_
.
value
();
TORCH_CHECK
(
colscale
.
is_cuda
())
TORCH_CHECK
(
colscale
.
is_cuda
())
;
TORCH_CHECK
(
colscale
.
is_contiguous
());
TORCH_CHECK
(
colscale
.
is_contiguous
());
TORCH_CHECK
(
colscale
.
sizes
()
==
c10
::
IntArrayRef
{
cols
});
TORCH_CHECK
(
colscale
.
sizes
()
==
c10
::
IntArrayRef
{
cols
});
TORCH_CHECK
(
colscale
.
dtype
()
==
wtype
);
TORCH_CHECK
(
colscale
.
dtype
()
==
wtype
);
...
@@ -154,7 +177,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
...
@@ -154,7 +177,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
if
(
x0_subset_
.
has_value
())
{
if
(
x0_subset_
.
has_value
())
{
auto
x0_subset
=
x0_subset_
.
value
();
auto
x0_subset
=
x0_subset_
.
value
();
TORCH_CHECK
(
x0_subset
.
is_cuda
())
TORCH_CHECK
(
x0_subset
.
is_cuda
())
;
TORCH_CHECK
(
x0_subset
.
is_contiguous
());
TORCH_CHECK
(
x0_subset
.
is_contiguous
());
TORCH_CHECK
(
x0_subset
.
sizes
()
==
c10
::
IntArrayRef
{
rows
});
TORCH_CHECK
(
x0_subset
.
sizes
()
==
c10
::
IntArrayRef
{
rows
});
TORCH_CHECK
(
x0_subset
.
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
x0_subset
.
dtype
()
==
torch
::
kInt32
);
...
@@ -167,9 +190,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
...
@@ -167,9 +190,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
TORCH_CHECK
(
z_subset
.
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
z_subset
.
dtype
()
==
torch
::
kInt32
);
}
}
TORCH_CHECK
(
hidden_size
==
cols
);
TORCH_CHECK
((
hidden_size
%
8
==
0
)
&&
(
hidden_size
<=
8192
));
TORCH_CHECK
((
hidden_size
%
8
==
0
)
&&
(
hidden_size
<=
6144
));
TORCH_CHECK
(
epsilon
>=
0.
f
);
TORCH_CHECK
(
epsilon
>=
0.
f
);
// Otherwise the kernel will be launched from cuda:0 device
// Otherwise the kernel will be launched from cuda:0 device
...
@@ -306,6 +327,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
...
@@ -306,6 +327,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
auto
cols
=
sizes
[
1
];
auto
cols
=
sizes
[
1
];
TORCH_CHECK
(
dz
.
dim
()
==
2
);
TORCH_CHECK
(
dz
.
dim
()
==
2
);
TORCH_CHECK
(
dz
.
size
(
1
)
==
cols
);
TORCH_CHECK
(
dz
.
size
(
1
)
==
cols
);
auto
hidden_size
=
gamma
.
numel
();
TORCH_CHECK
(
hidden_size
==
cols
);
// c10::IntArrayRef does not own the storage, so we need to construct a vector.
// c10::IntArrayRef does not own the storage, so we need to construct a vector.
// Otherwise just constructing IntArrayRef({blah}) will cause unintialized memory because
// Otherwise just constructing IntArrayRef({blah}) will cause unintialized memory because
...
@@ -316,7 +339,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
...
@@ -316,7 +339,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
if
(
dx_
.
has_value
())
{
if
(
dx_
.
has_value
())
{
auto
dx
=
dx_
.
value
();
auto
dx
=
dx_
.
value
();
TORCH_CHECK
(
dx
.
dtype
()
==
rtype
);
TORCH_CHECK
(
dx
.
dtype
()
==
rtype
);
TORCH_CHECK
(
dx
.
is_cuda
())
TORCH_CHECK
(
dx
.
is_cuda
())
;
TORCH_CHECK
(
dx
.
is_contiguous
());
TORCH_CHECK
(
dx
.
is_contiguous
());
TORCH_CHECK
(
dx
.
sizes
()
==
sizes
);
TORCH_CHECK
(
dx
.
sizes
()
==
sizes
);
}
}
...
@@ -331,7 +354,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
...
@@ -331,7 +354,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
if
(
rowscale_
.
has_value
())
{
if
(
rowscale_
.
has_value
())
{
auto
rowscale
=
rowscale_
.
value
();
auto
rowscale
=
rowscale_
.
value
();
TORCH_CHECK
(
rowscale
.
is_cuda
())
TORCH_CHECK
(
rowscale
.
is_cuda
())
;
TORCH_CHECK
(
rowscale
.
is_contiguous
());
TORCH_CHECK
(
rowscale
.
is_contiguous
());
TORCH_CHECK
(
rowscale
.
sizes
()
==
c10
::
IntArrayRef
{
rows
});
TORCH_CHECK
(
rowscale
.
sizes
()
==
c10
::
IntArrayRef
{
rows
});
TORCH_CHECK
(
rowscale
.
dtype
()
==
itype
);
TORCH_CHECK
(
rowscale
.
dtype
()
==
itype
);
...
@@ -339,14 +362,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
...
@@ -339,14 +362,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
if
(
colscale_
.
has_value
())
{
if
(
colscale_
.
has_value
())
{
auto
colscale
=
colscale_
.
value
();
auto
colscale
=
colscale_
.
value
();
TORCH_CHECK
(
colscale
.
is_cuda
())
TORCH_CHECK
(
colscale
.
is_cuda
())
;
TORCH_CHECK
(
colscale
.
is_contiguous
());
TORCH_CHECK
(
colscale
.
is_contiguous
());
TORCH_CHECK
(
colscale
.
sizes
()
==
c10
::
IntArrayRef
{
cols
});
TORCH_CHECK
(
colscale
.
sizes
()
==
c10
::
IntArrayRef
{
cols
});
TORCH_CHECK
(
colscale
.
dtype
()
==
wtype
);
TORCH_CHECK
(
colscale
.
dtype
()
==
wtype
);
TORCH_CHECK
(
x0_
.
has_value
());
TORCH_CHECK
(
x0_
.
has_value
());
auto
x0
=
x0_
.
value
();
auto
x0
=
x0_
.
value
();
TORCH_CHECK
(
x0
.
is_cuda
())
TORCH_CHECK
(
x0
.
is_cuda
())
;
TORCH_CHECK
(
x0
.
is_contiguous
());
TORCH_CHECK
(
x0
.
is_contiguous
());
TORCH_CHECK
(
x0
.
sizes
()
==
x0_sizes
);
TORCH_CHECK
(
x0
.
sizes
()
==
x0_sizes
);
TORCH_CHECK
(
x0
.
dtype
()
==
itype
);
TORCH_CHECK
(
x0
.
dtype
()
==
itype
);
...
@@ -354,7 +377,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
...
@@ -354,7 +377,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
if
(
x0_subset_
.
has_value
())
{
if
(
x0_subset_
.
has_value
())
{
auto
x0_subset
=
x0_subset_
.
value
();
auto
x0_subset
=
x0_subset_
.
value
();
TORCH_CHECK
(
x0_subset
.
is_cuda
())
TORCH_CHECK
(
x0_subset
.
is_cuda
())
;
TORCH_CHECK
(
x0_subset
.
is_contiguous
());
TORCH_CHECK
(
x0_subset
.
is_contiguous
());
TORCH_CHECK
(
x0_subset
.
sizes
()
==
c10
::
IntArrayRef
{
rows
});
TORCH_CHECK
(
x0_subset
.
sizes
()
==
c10
::
IntArrayRef
{
rows
});
TORCH_CHECK
(
x0_subset
.
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
x0_subset
.
dtype
()
==
torch
::
kInt32
);
...
@@ -367,9 +390,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
...
@@ -367,9 +390,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
TORCH_CHECK
(
z_subset
.
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
z_subset
.
dtype
()
==
torch
::
kInt32
);
}
}
auto
hidden_size
=
gamma
.
numel
();
TORCH_CHECK
((
hidden_size
%
8
==
0
)
&&
(
hidden_size
<=
8192
));
TORCH_CHECK
(
hidden_size
==
cols
);
TORCH_CHECK
((
hidden_size
%
8
==
0
)
&&
(
hidden_size
<=
6144
));
TORCH_CHECK
(
mu
.
numel
()
==
rows
);
TORCH_CHECK
(
mu
.
numel
()
==
rows
);
TORCH_CHECK
(
mu
.
sizes
()
==
rsigma
.
sizes
());
TORCH_CHECK
(
mu
.
sizes
()
==
rsigma
.
sizes
());
...
@@ -457,18 +478,373 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
...
@@ -457,18 +478,373 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
}
}
return
result
;
return
result
;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std
::
vector
<
at
::
Tensor
>
dropout_add_ln_parallel_residual_fwd
(
const
at
::
Tensor
&
x0
,
// Input: BxSxhidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
x1_
,
// Input: BxSxhidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
residual_
,
// Residual: BxSxhidden_size
const
at
::
Tensor
&
gamma0
,
// hidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
beta0_
,
// hidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
gamma1_
,
// hidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
beta1_
,
// hidden_size
const
float
dropout_p
,
const
float
epsilon
,
c10
::
optional
<
at
::
Generator
>
gen_
,
bool
residual_in_fp32
=
false
,
bool
is_rms_norm
=
false
)
{
auto
itype
=
x0
.
scalar_type
();
auto
rtype
=
residual_
.
has_value
()
?
residual_
.
value
().
scalar_type
()
:
(
residual_in_fp32
?
torch
::
kFloat32
:
x0
.
scalar_type
());
auto
wtype
=
gamma0
.
scalar_type
();
auto
otype
=
itype
;
auto
ctype
=
torch
::
kFloat32
;
auto
mtype
=
torch
::
kUInt8
;
TORCH_CHECK
(
x0
.
is_cuda
());
TORCH_CHECK
(
gamma0
.
is_cuda
());
TORCH_CHECK
(
x0
.
is_contiguous
());
const
auto
sizes
=
x0
.
sizes
();
TORCH_CHECK
(
x0
.
dim
()
==
2
);
const
int
rows
=
sizes
[
0
];
const
int
cols
=
sizes
[
1
];
auto
hidden_size
=
gamma0
.
numel
();
TORCH_CHECK
(
hidden_size
==
cols
);
if
(
x1_
.
has_value
())
{
auto
x1
=
x1_
.
value
();
TORCH_CHECK
(
x1
.
is_cuda
());
TORCH_CHECK
(
x1
.
is_contiguous
());
TORCH_CHECK
(
x1
.
sizes
()
==
sizes
);
}
if
(
residual_
.
has_value
())
{
auto
residual
=
residual_
.
value
();
TORCH_CHECK
(
residual
.
is_cuda
());
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
residual
.
sizes
()
==
sizes
);
}
if
(
beta0_
.
has_value
())
{
auto
beta0
=
beta0_
.
value
();
TORCH_CHECK
(
beta0
.
dtype
()
==
wtype
);
TORCH_CHECK
(
beta0
.
is_cuda
());
TORCH_CHECK
(
beta0
.
is_contiguous
());
TORCH_CHECK
(
beta0
.
sizes
()
==
gamma0
.
sizes
());
}
if
(
gamma1_
.
has_value
())
{
auto
gamma1
=
gamma1_
.
value
();
TORCH_CHECK
(
gamma1
.
dtype
()
==
wtype
);
TORCH_CHECK
(
gamma1
.
is_cuda
());
TORCH_CHECK
(
gamma1
.
is_contiguous
());
TORCH_CHECK
(
gamma1
.
sizes
()
==
gamma0
.
sizes
());
}
if
(
beta1_
.
has_value
())
{
auto
beta1
=
beta1_
.
value
();
TORCH_CHECK
(
beta1
.
dtype
()
==
wtype
);
TORCH_CHECK
(
beta1
.
is_cuda
());
TORCH_CHECK
(
beta1
.
is_contiguous
());
TORCH_CHECK
(
beta1
.
sizes
()
==
gamma0
.
sizes
());
}
TORCH_CHECK
((
hidden_size
%
8
==
0
)
&&
(
hidden_size
<=
8192
));
TORCH_CHECK
(
epsilon
>=
0.
f
);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
x0
.
get_device
()};
auto
opts
=
x0
.
options
();
bool
save_x
=
residual_
.
has_value
()
||
x1_
.
has_value
()
||
(
dropout_p
>
0.
f
)
||
(
itype
!=
rtype
);
at
::
Tensor
x
;
if
(
save_x
)
{
x
=
torch
::
empty
(
sizes
,
opts
.
dtype
(
rtype
));
}
at
::
Tensor
dmask0
,
dmask1
;
if
(
dropout_p
>
0.
f
)
{
dmask0
=
torch
::
empty
(
x0
.
sizes
(),
opts
.
dtype
(
mtype
));
if
(
x1_
.
has_value
())
{
dmask1
=
torch
::
empty
(
x0
.
sizes
(),
opts
.
dtype
(
mtype
));
}
};
auto
z0
=
torch
::
empty
(
sizes
,
opts
.
dtype
(
otype
));
at
::
Tensor
z1
;
if
(
gamma1_
.
has_value
())
{
z1
=
torch
::
empty
(
sizes
,
opts
.
dtype
(
otype
));
}
auto
mu
=
torch
::
empty
({
rows
},
opts
.
dtype
(
ctype
));
auto
rsigma
=
torch
::
empty
({
rows
},
opts
.
dtype
(
ctype
));
layer_norm
::
LaunchParams
<
layer_norm
::
FwdParams
>
launch_params
;
launch_params
.
props
=
at
::
cuda
::
getCurrentDeviceProperties
();
launch_params
.
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
TORCH_CHECK
(
dropout_p
<
1.
f
);
launch_params
.
params
.
dropout_keep_p
=
1.
f
-
dropout_p
;
launch_params
.
params
.
residual
=
residual_
.
has_value
()
?
residual_
.
value
().
data_ptr
()
:
nullptr
;
auto
gen
=
at
::
get_generator_or_default
<
at
::
CUDAGeneratorImpl
>
(
gen_
,
at
::
cuda
::
detail
::
getDefaultCUDAGenerator
());
auto
round_multiple
=
[](
int
x
,
int
m
)
{
return
(
x
+
m
-
1
)
/
m
*
m
;
};
const
int
multiple
=
hidden_size
<=
1536
?
256
:
(
hidden_size
<=
3072
?
512
:
1024
);
// Request the kernel launcher.
auto
launcher
=
get_parallel_fwd_launcher
(
wtype
,
itype
,
rtype
,
otype
,
ctype
,
round_multiple
(
hidden_size
,
multiple
));
// Query the kernel-specific launch parameters.
launcher
(
launch_params
,
true
);
at
::
Tensor
workspace
,
barrier
;
// Set the kernel runtime parameters.
layer_norm
::
FwdParams
&
params
=
launch_params
.
params
;
params
.
rows
=
rows
;
params
.
cols
=
cols
;
params
.
x0
=
x0
.
data_ptr
();
params
.
x1
=
x1_
.
has_value
()
?
x1_
.
value
().
data_ptr
()
:
nullptr
;
params
.
x
=
save_x
?
x
.
data_ptr
()
:
nullptr
;
params
.
dmask
=
dropout_p
>
0.
f
?
dmask0
.
data_ptr
()
:
nullptr
;
params
.
dmask1
=
(
dropout_p
>
0.
f
&&
x1_
.
has_value
())
?
dmask1
.
data_ptr
()
:
nullptr
;
params
.
mu
=
mu
.
data_ptr
();
params
.
rs
=
rsigma
.
data_ptr
();
params
.
gamma
=
gamma0
.
data_ptr
();
params
.
gamma1
=
gamma1_
.
has_value
()
?
gamma1_
.
value
().
data_ptr
()
:
nullptr
;
params
.
beta
=
beta0_
.
has_value
()
?
beta0_
.
value
().
data_ptr
()
:
nullptr
;
params
.
beta1
=
beta1_
.
has_value
()
?
beta1_
.
value
().
data_ptr
()
:
nullptr
;
params
.
z
=
z0
.
data_ptr
();
params
.
z1
=
gamma1_
.
has_value
()
?
z1
.
data_ptr
()
:
nullptr
;
params
.
epsilon
=
epsilon
;
params
.
dropout_scale
=
1.
f
/
(
1.
f
-
dropout_p
);
params
.
inverse_cols
=
1.
f
/
float
(
params
.
cols
);
params
.
is_rms_norm
=
is_rms_norm
;
if
(
dropout_p
>
0.
f
)
{
// number of times random will be generated per thread, to offset philox counter in thc random
// state
int64_t
counter_offset
=
2
*
launch_params
.
elts_per_thread
;
// See Note [Acquire lock when using random generators]
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
gen
->
mutex_
);
params
.
philox_args
=
gen
->
philox_cuda_state
(
counter_offset
);
}
}
if
(
launch_params
.
barrier_size
>
0
)
{
auto
options
=
x0
.
options
();
barrier
=
torch
::
zeros
(
launch_params
.
barrier_size
,
options
.
dtype
(
torch
::
kInt32
));
workspace
=
torch
::
empty
(
launch_params
.
workspace_bytes
,
options
.
dtype
(
torch
::
kChar
));
params
.
workspace
=
workspace
.
data_ptr
();
params
.
barrier
=
barrier
.
data_ptr
<
int
>
();
}
// Launch the kernel.
launcher
(
launch_params
,
false
);
return
{
z0
,
z1
,
x
,
dmask0
,
dmask1
,
mu
,
rsigma
};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std
::
vector
<
at
::
Tensor
>
dropout_add_ln_parallel_residual_bwd
(
const
at
::
Tensor
&
dz0
,
// BxSxhidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
dz1_
,
// BxSxhidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
dx_
,
// BxSxhidden_size
const
at
::
Tensor
&
x
,
// BxSxhidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
dmask0_
,
// BxSxhidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
dmask1_
,
// BxSxhidden_size
const
at
::
Tensor
&
mu
,
// BxS, FP32!
const
at
::
Tensor
&
rsigma
,
// BxS, FP32!
const
at
::
Tensor
&
gamma0
,
// hidden_size
c10
::
optional
<
const
at
::
Tensor
>
&
gamma1_
,
// hidden_size
const
float
dropout_p
,
const
bool
has_x1
,
const
bool
has_residual
,
bool
is_rms_norm
=
false
)
{
auto
itype
=
dz0
.
scalar_type
();
auto
rtype
=
x
.
scalar_type
();
auto
wtype
=
gamma0
.
scalar_type
();
auto
otype
=
itype
;
auto
ctype
=
torch
::
kFloat32
;
auto
mtype
=
torch
::
kUInt8
;
if
(
dropout_p
>
0.
f
)
{
TORCH_CHECK
(
dmask0_
.
has_value
());
}
TORCH_CHECK
(
dz0
.
dtype
()
==
otype
);
TORCH_CHECK
(
dz0
.
dtype
()
==
otype
);
TORCH_CHECK
(
mu
.
dtype
()
==
ctype
);
TORCH_CHECK
(
rsigma
.
dtype
()
==
ctype
);
TORCH_CHECK
(
x
.
is_cuda
());
TORCH_CHECK
(
dz0
.
is_cuda
());
TORCH_CHECK
(
mu
.
is_cuda
());
TORCH_CHECK
(
rsigma
.
is_cuda
());
TORCH_CHECK
(
gamma0
.
is_cuda
());
TORCH_CHECK
(
x
.
is_contiguous
());
TORCH_CHECK
(
dz0
.
is_contiguous
());
auto
sizes
=
x
.
sizes
();
TORCH_CHECK
(
sizes
.
size
()
==
2
);
auto
rows
=
sizes
[
0
];
auto
cols
=
sizes
[
1
];
TORCH_CHECK
(
dz0
.
dim
()
==
2
);
TORCH_CHECK
(
dz0
.
size
(
1
)
==
cols
);
auto
hidden_size
=
gamma0
.
numel
();
TORCH_CHECK
(
hidden_size
==
cols
);
if
(
dz1_
.
has_value
())
{
auto
dz1
=
dz1_
.
value
();
TORCH_CHECK
(
dz1
.
dtype
()
==
otype
);
TORCH_CHECK
(
dz1
.
is_cuda
());
TORCH_CHECK
(
dz1
.
is_contiguous
());
TORCH_CHECK
(
dz1
.
sizes
()
==
sizes
);
TORCH_CHECK
(
gamma1_
.
has_value
());
auto
gamma1
=
gamma1_
.
value
();
TORCH_CHECK
(
gamma1
.
dtype
()
==
wtype
);
TORCH_CHECK
(
gamma1
.
is_cuda
());
TORCH_CHECK
(
gamma1
.
is_contiguous
());
TORCH_CHECK
(
gamma1
.
sizes
()
==
gamma0
.
sizes
());
}
if
(
dx_
.
has_value
())
{
auto
dx
=
dx_
.
value
();
TORCH_CHECK
(
dx
.
dtype
()
==
rtype
);
TORCH_CHECK
(
dx
.
is_cuda
());
TORCH_CHECK
(
dx
.
is_contiguous
());
TORCH_CHECK
(
dx
.
sizes
()
==
sizes
);
}
if
(
dmask0_
.
has_value
())
{
auto
dmask0
=
dmask0_
.
value
();
TORCH_CHECK
(
dmask0
.
dtype
()
==
mtype
);
TORCH_CHECK
(
dmask0
.
is_cuda
());
TORCH_CHECK
(
dmask0
.
is_contiguous
());
TORCH_CHECK
(
dmask0
.
sizes
()
==
sizes
);
if
(
has_x1
)
{
TORCH_CHECK
(
dmask1_
.
has_value
());
auto
dmask1
=
dmask1_
.
value
();
TORCH_CHECK
(
dmask1
.
dtype
()
==
mtype
);
TORCH_CHECK
(
dmask1
.
is_cuda
());
TORCH_CHECK
(
dmask1
.
is_contiguous
());
TORCH_CHECK
(
dmask1
.
sizes
()
==
sizes
);
}
}
TORCH_CHECK
((
hidden_size
%
8
==
0
)
&&
(
hidden_size
<=
8192
));
TORCH_CHECK
(
mu
.
numel
()
==
rows
);
TORCH_CHECK
(
mu
.
sizes
()
==
rsigma
.
sizes
());
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
dz0
.
get_device
()};
auto
opts
=
x
.
options
();
auto
dx0
=
torch
::
empty
(
sizes
,
opts
.
dtype
(
itype
));
at
::
Tensor
dx1
;
if
(
has_x1
)
{
dx1
=
torch
::
empty
(
sizes
,
opts
.
dtype
(
itype
));
}
at
::
Tensor
dresidual
;
if
(
has_residual
)
{
dresidual
=
torch
::
empty_like
(
x
,
opts
.
dtype
(
rtype
));
}
auto
dgamma0
=
torch
::
empty_like
(
gamma0
);
auto
dbeta0
=
torch
::
empty_like
(
gamma0
);
at
::
Tensor
dgamma1
,
dbeta1
;
if
(
gamma1_
.
has_value
())
{
dgamma1
=
torch
::
empty_like
(
gamma0
);
dbeta1
=
torch
::
empty_like
(
gamma0
);
}
layer_norm
::
LaunchParams
<
layer_norm
::
BwdParams
>
launch_params
;
launch_params
.
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
launch_params
.
props
=
at
::
cuda
::
getCurrentDeviceProperties
();
TORCH_CHECK
(
dropout_p
<
1.
f
);
launch_params
.
params
.
dropout_keep_p
=
1.
f
-
dropout_p
;
launch_params
.
params
.
dresidual
=
has_residual
?
dresidual
.
data_ptr
()
:
nullptr
;
auto
round_multiple
=
[](
int
x
,
int
m
)
{
return
(
x
+
m
-
1
)
/
m
*
m
;
};
const
int
multiple
=
hidden_size
<=
1536
?
256
:
(
hidden_size
<=
3072
?
512
:
1024
);
auto
launcher
=
get_parallel_bwd_launcher
(
wtype
,
itype
,
rtype
,
otype
,
ctype
,
round_multiple
(
hidden_size
,
multiple
));
launcher
(
launch_params
,
true
);
auto
dgamma0_part
=
torch
::
zeros
({
launch_params
.
params
.
ctas_per_col
,
hidden_size
},
opts
.
dtype
(
ctype
));
auto
dbeta0_part
=
torch
::
zeros
({
launch_params
.
params
.
ctas_per_col
,
hidden_size
},
opts
.
dtype
(
ctype
));
at
::
Tensor
dgamma1_part
,
dbeta1_part
;
if
(
gamma1_
.
has_value
())
{
dgamma1_part
=
torch
::
zeros_like
(
dgamma0_part
);
dbeta1_part
=
torch
::
zeros_like
(
dbeta0_part
);
}
at
::
Tensor
workspace
,
barrier
;
layer_norm
::
BwdParams
&
params
=
launch_params
.
params
;
params
.
rows
=
rows
;
params
.
cols
=
cols
;
params
.
x
=
x
.
data_ptr
();
params
.
dmask
=
dropout_p
>
0.
f
?
dmask0_
.
value
().
data_ptr
()
:
nullptr
;
params
.
dmask1
=
(
dropout_p
>
0.
f
&&
has_x1
)
?
dmask1_
.
value
().
data_ptr
()
:
nullptr
;
params
.
mu
=
mu
.
data_ptr
();
params
.
rs
=
rsigma
.
data_ptr
();
params
.
gamma
=
gamma0
.
data_ptr
();
params
.
gamma1
=
gamma1_
.
has_value
()
?
gamma1_
.
value
().
data_ptr
()
:
nullptr
;
params
.
dz
=
dz0
.
data_ptr
();
params
.
dz1
=
dz1_
.
has_value
()
?
dz1_
.
value
().
data_ptr
()
:
nullptr
;
params
.
dx
=
dx_
.
has_value
()
?
dx_
.
value
().
data_ptr
()
:
nullptr
;
params
.
dx0
=
dx0
.
data_ptr
();
params
.
dx1
=
has_x1
?
dx1
.
data_ptr
()
:
nullptr
;
params
.
dbeta
=
dbeta0
.
data_ptr
();
params
.
dgamma
=
dgamma0
.
data_ptr
();
params
.
dbeta1
=
gamma1_
.
has_value
()
?
dbeta1
.
data_ptr
()
:
nullptr
;
params
.
dgamma1
=
gamma1_
.
has_value
()
?
dgamma1
.
data_ptr
()
:
nullptr
;
params
.
dbeta_part
=
dbeta0_part
.
data_ptr
();
params
.
dgamma_part
=
dgamma0_part
.
data_ptr
();
params
.
dbeta1_part
=
gamma1_
.
has_value
()
?
dbeta1_part
.
data_ptr
()
:
nullptr
;
params
.
dgamma1_part
=
gamma1_
.
has_value
()
?
dgamma1_part
.
data_ptr
()
:
nullptr
;
params
.
dropout_scale
=
1.
f
/
(
1.
f
-
dropout_p
);
params
.
inverse_cols
=
1.
f
/
float
(
params
.
cols
);
params
.
is_rms_norm
=
is_rms_norm
;
if
(
launch_params
.
barrier_size
>
0
)
{
// TODO Any way to avoid this?
barrier
=
torch
::
zeros
(
launch_params
.
barrier_size
,
opts
.
dtype
(
torch
::
kInt32
));
workspace
=
torch
::
empty
(
launch_params
.
workspace_bytes
,
opts
.
dtype
(
torch
::
kChar
));
params
.
workspace
=
workspace
.
data_ptr
();
params
.
barrier
=
barrier
.
data_ptr
<
int
>
();
}
launcher
(
launch_params
,
false
);
std
::
vector
<
at
::
Tensor
>
result
=
{
dx0
,
dx1
,
dresidual
,
dgamma0
,
dbeta0
,
dgamma1
,
dbeta1
,
dgamma0_part
,
dbeta0_part
,
dgamma1_part
,
dbeta1_part
};
return
result
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
// Python bindings for the fused Dropout + (optional) Add + LayerNorm/RMSNorm CUDA kernels.
// NOTE(review): the extracted source contained a nested, duplicated PYBIND11_MODULE block with
// interleaved old/new argument lists (a merged side-by-side diff artifact); this is the single,
// deduplicated module definition using the new-side keyword names (e.g. "beta_" for the
// now-optional beta tensor).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "CUDA DropoutAddLayerNorm";
    // Forward: z = LN(dropout(x0 [*rowscale/colscale]) + residual). Trailing-underscore
    // arguments are optional tensors.
    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
          py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta_"),
          py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
          py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
          py::arg("gen_"), py::arg("residual_in_fp32") = false, py::arg("is_rms_norm") = false);
    // Backward of the above; dmask_ is the saved dropout mask, mu/rsigma the saved statistics.
    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
          py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"),
          py::arg("mu"), py::arg("rsigma"), py::arg("gamma"),
          py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
          py::arg("dropout_p"), py::arg("rowscale_const"), py::arg("x0_numrows"),
          py::arg("has_residual"), py::arg("is_rms_norm") = false);
    // Parallel-residual variant: two input streams (x0, optional x1) normalized with
    // gamma0/beta0 and optionally gamma1/beta1.
    m.def("dropout_add_ln_parallel_residual_fwd", &dropout_add_ln_parallel_residual_fwd,
          "Run Dropout + Add + LayerNorm parallel residual forward kernel",
          py::arg("x0"), py::arg("x1_"), py::arg("residual"), py::arg("gamma0"), py::arg("beta0_"),
          py::arg("gamma1_"), py::arg("beta1_"), py::arg("dropout_p"), py::arg("epsilon"),
          py::arg("gen_"), py::arg("residual_in_fp32") = false, py::arg("is_rms_norm") = false);
    // Backward of the parallel-residual variant.
    m.def("dropout_add_ln_parallel_residual_bwd", &dropout_add_ln_parallel_residual_bwd,
          "Run Dropout + Add + LayerNorm parallel residual backward kernel",
          py::arg("dz0"), py::arg("dz1_"), py::arg("dx_"), py::arg("x"),
          py::arg("dmask0_"), py::arg("dmask1_"), py::arg("mu"), py::arg("rsigma"),
          py::arg("gamma0"), py::arg("gamma1_"), py::arg("dropout_p"),
          py::arg("has_x1"), py::arg("has_residual"), py::arg("is_rms_norm") = false);
}
csrc/layer_norm/ln_bwd_7168.cu
0 → 100644
View file @
393882bc
#include "ln_bwd_kernels.cuh"

// Instantiate and register backward launchers for hidden size 7168.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
\ No newline at end of file
csrc/layer_norm/ln_bwd_8192.cu
0 → 100644
View file @
393882bc
#include "ln_bwd_kernels.cuh"

// Instantiate and register backward launchers for hidden size 8192.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
\ No newline at end of file
csrc/layer_norm/ln_fwd_7168.cu
0 → 100644
View file @
393882bc
#include "ln_fwd_kernels.cuh"

// Instantiate and register forward launchers for hidden size 7168.
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
csrc/layer_norm/ln_fwd_8192.cu
0 → 100644
View file @
393882bc
#include "ln_fwd_kernels.cuh"

// Instantiate and register forward launchers for hidden size 8192.
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
csrc/layer_norm/ln_parallel_bwd_1024.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 1024.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
csrc/layer_norm/ln_parallel_bwd_1280.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 1280.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
csrc/layer_norm/ln_parallel_bwd_1536.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 1536.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
csrc/layer_norm/ln_parallel_bwd_2048.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 2048.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_256.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 256.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
csrc/layer_norm/ln_parallel_bwd_2560.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 2560.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
csrc/layer_norm/ln_parallel_bwd_3072.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 3072.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_4096.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 4096.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Use 8 warps otherwise there's a lot of register spilling.
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_512.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 512.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
csrc/layer_norm/ln_parallel_bwd_5120.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 5120.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// Use 8 warps otherwise there's a lot of register spilling.
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_6144.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 6144.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_7168.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 7168.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
\ No newline at end of file
csrc/layer_norm/ln_parallel_bwd_768.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"

// Instantiate and register parallel-residual backward launchers for hidden size 768.
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment