gaoqiong / flash-attention · Commits · 6738d947

Commit 6738d947, authored Jan 06, 2023 by Tri Dao
[LayerNorm] Implement RMS Norm
Parent: a1f49a2b
Showing 9 changed files with 205 additions and 80 deletions (+205, -80).
Changed files:
  csrc/layer_norm/README.md              +1   -0
  csrc/layer_norm/ln.h                   +3   -0
  csrc/layer_norm/ln_api.cpp             +26  -10
  csrc/layer_norm/ln_bwd_kernels.cuh     +2   -2
  csrc/layer_norm/ln_fwd_kernels.cuh     +7   -3
  csrc/layer_norm/ln_utils.cuh           +7   -0
  flash_attn/ops/layer_norm.py           +30  -22
  flash_attn/ops/rms_norm.py             +58  -0
  tests/ops/test_dropout_layer_norm.py   +71  -43
csrc/layer_norm/README.md

@@ -2,6 +2,7 @@ This CUDA extension implements fused dropout + residual + LayerNorm, building on
 Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
 We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
 We also make it work for more hidden dimensions (all dimensions divisible by 8, up to 6144).
+We also implement RMSNorm as an option.
 If you want to use it for dimensions larger than 6k, please file an issue.
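For context, the two norms differ only in the statistics they use: LayerNorm subtracts the per-row mean and divides by the standard deviation, then applies a weight and a bias, while RMSNorm skips the mean subtraction and the bias and divides by the root-mean-square of the row. A minimal fp32 reference sketch (the function names here are ours, not part of the extension):

```python
import torch

def layer_norm_ref(x, weight, bias, eps=1e-5):
    # LayerNorm: center by the mean, scale by 1/sqrt(variance + eps), then affine transform.
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mu) * torch.rsqrt(var + eps) * weight + bias

def rms_norm_ref(x, weight, eps=1e-5):
    # RMSNorm: no centering and no bias; scale by 1/sqrt(mean(x^2) + eps).
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
```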
csrc/layer_norm/ln.h

@@ -44,6 +44,7 @@ struct ParamsBase {
         , colscale(nullptr)
         , dropout_keep_p(1.f)
         , dropout_scale(1.f)
+        , is_rms_norm(false)
         , workspace(nullptr)
         , barrier(nullptr)
     {

@@ -75,6 +76,8 @@ struct ParamsBase {
     float dropout_scale;
     float rowscale_const;

+    bool is_rms_norm;
+
     // Multi-CTA workspace in gmem.
     void *workspace;
csrc/layer_norm/ln_api.cpp

@@ -83,7 +83,7 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
 std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input: BxSxhidden_size
                                            c10::optional<const at::Tensor> &x1_,      // Residual: BxSxhidden_size
                                            const at::Tensor &gamma,   // hidden_size
-                                           const at::Tensor &beta,   // hidden_size
+                                           c10::optional<const at::Tensor> &beta_,   // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,      // BxS
                                            c10::optional<const at::Tensor> &colscale_,      // hidden_size
                                            c10::optional<const at::Tensor> &x0_subset_,      // BxS

@@ -93,7 +93,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
                                            const float rowscale_const,
                                            const int64_t z_numrows,
                                            c10::optional<at::Generator> gen_,
-                                           bool residual_in_fp32
+                                           bool residual_in_fp32=false,
+                                           bool is_rms_norm=false
 ) {
     auto itype = x0.scalar_type();
     auto rtype = x1_.has_value()

@@ -104,11 +105,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     auto ctype = torch::kFloat32;
     auto mtype = torch::kUInt8;

-    TORCH_CHECK(beta.dtype() == wtype);
     TORCH_CHECK(x0.is_cuda())
     TORCH_CHECK(gamma.is_cuda())
-    TORCH_CHECK(beta.is_cuda())
     TORCH_CHECK(x0.is_contiguous());
     // c10::IntArrayRef does not own the storage, so we need to construct a vector.

@@ -123,6 +121,14 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     const int cols = sizes[1];
     auto hidden_size = gamma.numel();

+    if (beta_.has_value()) {
+        auto beta = beta_.value();
+        TORCH_CHECK(beta.dtype() == wtype);
+        TORCH_CHECK(beta.is_cuda())
+        TORCH_CHECK(beta.is_contiguous());
+        TORCH_CHECK(gamma.sizes() == beta.sizes());
+    }
+
     if (x1_.has_value()) {
         auto x1 = x1_.value();
         TORCH_CHECK(x1.is_cuda())

@@ -161,7 +167,6 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
         TORCH_CHECK(z_subset.dtype() == torch::kInt32);
     }

-    TORCH_CHECK(gamma.sizes() == beta.sizes());
     TORCH_CHECK(hidden_size == cols);
     TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));

@@ -218,12 +223,13 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input:
     params.mu = mu.data_ptr();
     params.rs = rsigma.data_ptr();
     params.gamma = gamma.data_ptr();
-    params.beta = beta.data_ptr();
+    params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr;
     params.z = z.data_ptr();
     params.epsilon = epsilon;
     params.dropout_scale = 1.f / (1.f - dropout_p);
     params.inverse_cols = 1.f / float(params.cols);
     params.rowscale_const = rowscale_const;
+    params.is_rms_norm = is_rms_norm;

     if (dropout_p > 0.f) {
         // number of times random will be generated per thread, to offset philox counter in thc random

@@ -268,7 +274,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
                                            const float dropout_p,
                                            const float rowscale_const,
                                            const int64_t x0_numrows,
-                                           const bool has_residual
+                                           const bool has_residual,
+                                           bool is_rms_norm=false
 ) {
     auto itype = dz.scalar_type();

@@ -431,6 +438,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
     params.dropout_scale = 1.f / (1.f - dropout_p);
     params.inverse_cols = 1.f / float(params.cols);
     params.rowscale_const = rowscale_const;
+    params.is_rms_norm = is_rms_norm;

     if (launch_params.barrier_size > 0) {
         // TODO Any way to avoid this?

@@ -453,6 +461,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidd
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "CUDA DropoutAddLayerNorm";
-    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
-    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
+    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
+          py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"), py::arg("rowscale_"),
+          py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"),
+          py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"), py::arg("gen_"),
+          py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
+    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
+          py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
+          py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
+          py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
+          py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
 }
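Because the bindings now declare py::arg names with defaults, the two trailing flags can be omitted from Python and the functions can also be called with keywords. A hedged sketch of a direct call (assuming the extension has been built and is importable as dropout_layer_norm; in practice the wrappers in flash_attn/ops/layer_norm.py below do this for you):

```python
import torch
import dropout_layer_norm  # compiled from csrc/layer_norm

hidden_size = 1024
x0 = torch.randn(8 * 512, hidden_size, device='cuda', dtype=torch.float16)  # (rows, hidden_size)
gamma = torch.ones(hidden_size, device='cuda', dtype=torch.float16)

# Positional order follows the binding above. beta and the other optional tensors may now be
# None, and the two trailing flags default to False when omitted.
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
    x0, None, gamma, None, None, None, None, None,   # x0, x1, gamma, beta, rowscale_, colscale_, x0_subset_, z_subset_
    0.0, 1e-5, 1.0, 0, None,                         # dropout_p, epsilon, rowscale_const, z_numrows, gen_
    False, True,                                     # residual_in_fp32, is_rms_norm
)
```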
csrc/layer_norm/ln_bwd_kernels.cuh

@@ -125,7 +125,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
         #pragma unroll
         for( int jt = 0; jt < NUM_ELTS; jt++ ) {
             compute_t x_tmp = x.data.elt[jt];
-            compute_t y_tmp = rs_r * (x_tmp - mu_r);
+            compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));
             compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
             compute_t dz_tmp = dz.data.elt[jt];

@@ -173,7 +173,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                 if (load_dz) {
                     compute_t dy_tmp = dy[it * NUM_ELTS + jt];
                     compute_t y_tmp = y[it * NUM_ELTS + jt];
-                    compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
+                    compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));
                     dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
                 } else {
                     dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
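The backward change mirrors the math: with no mean subtraction in the forward, the mean-of-gradient term drops out of dx while the mean(dy * y) term stays. A standalone fp64 cross-check of that formula against autograd (a sketch, not part of the repo):

```python
import torch

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(4, 256, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(256, dtype=torch.float64)
dz = torch.randn(4, 256, dtype=torch.float64)

# RMSNorm forward: y = x * rs with rs = 1/sqrt(mean(x^2) + eps); no mu, no beta.
rs = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
y = x * rs
(gamma * y).backward(dz)

# Kernel formula dx = rs * (dy - (mean(dy*y) * y + mean(dy))), with the mean(dy)
# term zeroed out when params.is_rms_norm is set.
with torch.no_grad():
    dy = dz * gamma
    dx = rs * (dy - y * (dy * y).mean(dim=-1, keepdim=True))
print(torch.allclose(x.grad, dx))  # True
```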
csrc/layer_norm/ln_fwd_kernels.cuh

@@ -89,7 +89,11 @@ void ln_fwd_kernel(FwdParams params) {
     for( int it = 0; it < LDGS; it++ ) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             gamma[it].load_from(params.gamma, idx);
-            beta[it].load_from(params.beta, idx);
+            if (params.beta != nullptr) {
+                beta[it].load_from(params.beta, idx);
+            } else {
+                beta[it].zero_();
+            }
             if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
             idx += VEC_COLS_PER_LDG;
         }

@@ -159,7 +163,7 @@ void ln_fwd_kernel(FwdParams params) {
            mu_ptr[row] = mu;
        }

-        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon);
+        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));

        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
            rs_ptr[row] = rs;

@@ -174,7 +178,7 @@ void ln_fwd_kernel(FwdParams params) {
            Ovec z;
            #pragma unroll
            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
-                compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu));
+                compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
                compute_t g_ij = gamma[it].data.elt[jt];
                compute_t b_ij = beta[it].data.elt[jt];
                z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
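The forward kernel keeps its single statistics pass: it already produces the mean mu and the variance m2 * inverse_cols, and for RMSNorm it folds mu back in, since Var(x) + mean(x)^2 = E[x^2] is exactly the mean square that RMSNorm normalizes by. A quick numeric check of that identity (standalone sketch):

```python
import torch

x = torch.randn(8, 1024, dtype=torch.float64)
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)   # plays the role of m2 * params.inverse_cols
mean_sq = x.pow(2).mean(dim=-1, keepdim=True)       # what RMSNorm divides by (before eps)
print(torch.allclose(var + mu * mu, mean_sq))       # True
```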
csrc/layer_norm/ln_utils.cuh

@@ -308,6 +308,13 @@ struct Vec {
         }
     }

+    inline __device__ void zero_() {
+        #pragma unroll
+        for( int it = 0; it < NUM_ELT; it++ ) {
+            this->data.elt[it] = Elt_type(0.f);
+        }
+    }
+
     inline __device__ void load_from(const void *base_ptr, const size_t idx) {
         this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
     }
flash_attn/ops/layer_norm.py

@@ -8,7 +8,7 @@ import dropout_layer_norm
 def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
-                                    residual_in_fp32):
+                                    residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()

@@ -17,7 +17,7 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
         x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
-        1.0, 0, None, residual_in_fp32
+        1.0, 0, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype

@@ -25,7 +25,7 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro
 def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
-                                     dropout_p, has_residual):
+                                     dropout_p, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + x1 was not returned in the fwd).

@@ -41,7 +41,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
         assert x0 is not None, 'x0 is required to compute the gradient of colscale'
     dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
-        dropout_p, 1.0, 0, has_residual
+        dropout_p, 1.0, 0, has_residual, is_rms_norm
     )
     # dx1mat is None if not has_residual
     if colscale is None:

@@ -53,7 +53,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
 def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
                                            dropout_p, epsilon, rowscale_const, out_numrows,
-                                           residual_in_fp32):
+                                           residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()

@@ -63,7 +63,7 @@ def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_sub
     out_subset = out_subset.view(-1) if out_subset is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
         x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
-        rowscale_const, out_numrows, None, residual_in_fp32
+        rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype

@@ -72,7 +72,7 @@ def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_sub
 def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                             x0_subset, out_subset, dropout_p, rowscale_const,
-                                            x0_numrows, has_residual):
+                                            x0_numrows, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + x1 was not returned in the fwd).

@@ -89,7 +89,7 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
         assert x0 is not None, 'x0 is required to compute the gradient of colscale'
     dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
-        dropout_p, rowscale_const, x0_numrows, has_residual
+        dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
     )
     # dx1mat is None if not has_residual
     if colscale is None:

@@ -101,16 +101,17 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
 class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
-                prenorm=False, return_dmask=False):
+    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
         x1 = x1.contiguous() if x1 is not None else None
         gamma = gamma.contiguous()
-        beta = beta.contiguous()
+        beta = beta.contiguous() if beta is not None else None
         rowscale = rowscale.contiguous() if rowscale is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
+            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
+            is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None

@@ -118,6 +119,8 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
         ctx.has_residual = x1 is not None
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_beta = beta is not None
         if not return_dmask:
             return (zmat.view(x0.shape) if not prenorm
                     else (zmat.view(x0.shape), xmat.view(x0.shape)))

@@ -138,26 +141,29 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
         dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
-            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual
+            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
+            ctx.is_rms_norm
         )
         dx0 = dx0mat.view(x.shape)
         dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None
+        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None,
+                None, None, None, None)


 class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
-                rowscale_const, out_numrows, residual_in_fp32, prenorm=False, return_dmask=False):
+                rowscale_const, out_numrows, residual_in_fp32=False, prenorm=False,
+                is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
         x1 = x1.contiguous() if x1 is not None else None
         gamma = gamma.contiguous()
-        beta = beta.contiguous()
+        beta = beta.contiguous() if beta is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
             x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
-            rowscale_const, out_numrows, residual_in_fp32
+            rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None

@@ -169,6 +175,8 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         ctx.rowscale_const = rowscale_const
         ctx.x0_numrows = x0.shape[:-1].numel()
         ctx.has_residual = x1 is not None
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_beta = beta is not None
         z_shape = (-1, *x0.shape[1:])
         if not return_dmask:
             return (zmat.view(z_shape) if not prenorm

@@ -191,13 +199,13 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         has_residual = ctx.has_residual
         dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
             dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
-            ctx.rowscale_const, ctx.x0_numrows, has_residual
+            ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
         )
         dx0 = dx0mat.view(-1, *x.shape[1:])
         dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return (dx0, dx1, dgamma, dbeta, dcolscale, None, None, None, None, None, None, None,
-                None)
+        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None,
+                None, None, None, None, None, None, None)


 def layer_norm(x, weight, bias, epsilon):

@@ -212,7 +220,7 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No
     """
     return DropoutAddLayerNormFn.apply(
         x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
-        return_dropout_mask
+        False, return_dropout_mask
     )

@@ -225,7 +233,7 @@ def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, laye
     """
     return DropoutAddLayerNormSubsetFn.apply(
         x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
-        rowscale_const, out_numrows, residual_in_fp32, prenorm, return_dropout_mask
+        rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
     )
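The longer return tuples in the two backward methods are dictated by torch.autograd.Function: backward must return exactly one value per argument of forward, with None for non-tensor or non-differentiable inputs, so adding is_rms_norm (and making beta optional) grows the tuple. A minimal standalone illustration of that contract (not the repo's code):

```python
import torch

class ScaleFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, some_flag=False):
        ctx.save_for_backward(x, weight)
        ctx.some_flag = some_flag          # non-tensor state lives on ctx, like is_rms_norm above
        return x * weight

    @staticmethod
    def backward(ctx, dout):
        x, weight = ctx.saved_tensors
        # One return value per forward argument (x, weight, some_flag):
        # gradients for the tensors, None for the non-differentiable flag.
        return dout * weight, (dout * x).sum(dim=0), None

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, requires_grad=True)
ScaleFn.apply(x, w, True).sum().backward()
print(x.grad.shape, w.grad.shape)  # torch.Size([4, 8]) torch.Size([8])
```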
flash_attn/ops/rms_norm.py (new file, mode 100644)

# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn


def rms_norm(x, weight, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
                                       False, True)


def dropout_add_rms_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
                         prenorm=False, residual_in_fp32=False, return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
        True, return_dropout_mask
    )


def dropout_add_rms_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
                                x0_subset=None, out_subset=None, rowscale_const=1.0, out_numrows=0,
                                prenorm=False, residual_in_fp32=False, return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
    )


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, x1=None):
        return dropout_add_rms_norm(x0, x1, self.weight, None,
                                    self.p if self.training else 0.0, self.epsilon,
                                    prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
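Usage mirrors the existing DropoutAddLayerNorm wrappers, except that there is no bias. A hedged sketch (assumes the dropout_layer_norm extension is built and a CUDA device is available):

```python
import torch
from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm

hidden_size = 1024
x0 = torch.randn(2, 512, hidden_size, device='cuda', dtype=torch.float16)
residual = torch.randn_like(x0)

# Module form: fused dropout(x0) + residual followed by RMSNorm (weight only, bias is None).
norm = DropoutAddRMSNorm(hidden_size, prenorm=True, p=0.1, device='cuda', dtype=torch.float16)
out, pre_norm = norm(x0, residual)   # with prenorm=True the updated residual stream is also returned

# Functional form that the module wraps.
out2, pre_norm2 = dropout_add_rms_norm(x0, residual, norm.weight, None, 0.1, norm.epsilon,
                                       prenorm=True)
```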
tests/ops/test_dropout_layer_norm.py

@@ -8,11 +8,20 @@ from einops import rearrange, repeat
 from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
 from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
+from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm
+from flash_attn.ops.rms_norm import dropout_add_rms_norm_subset
+
+try:
+    from apex.normalization import FusedRMSNorm
+except:
+    FusedRMSNorm = None


 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


+@pytest.mark.parametrize('is_rms_norm', [False, True])
 @pytest.mark.parametrize('has_colscale', [True, False])
 # @pytest.mark.parametrize('has_colscale', [False])
 @pytest.mark.parametrize('has_rowscale', [True, False])
 # @pytest.mark.parametrize('has_rowscale', [True])
 @pytest.mark.parametrize('has_residual', [True, False])

@@ -26,11 +35,17 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
                           (torch.float32, torch.float32)]
                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
 # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
-@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+@pytest.mark.parametrize('hidden_size', [256])
 def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
-                                     dropout_p, has_residual, has_rowscale, has_colscale):
+                                     dropout_p, has_residual, has_rowscale, has_colscale, is_rms_norm):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
         pytest.skip()  # Not supported
+    if is_rms_norm and FusedRMSNorm is None:
+        pytest.skip()  # We need Apex's FusedRMSNorm to test
+    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
+    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
+    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
     device = 'cuda'
     # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
     rtol, atol = (1e-3, 1e-4)

@@ -67,20 +82,22 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
     if has_colscale:
         x0_scaled_pt = x0_scaled_pt * colscale_pt
         x0_scaled_ref = x0_scaled_ref * colscale_ref
-    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
-    torch.nn.init.normal_(model_pt.bias)
-    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
-    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
+    if not is_rms_norm:
+        torch.nn.init.normal_(model_pt.bias)
+    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
+    model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
     with torch.no_grad():
         model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
         model_ref.weight.copy_(model_pt.weight)
-        model_ref.bias.copy_(model_pt.bias)
+        if not is_rms_norm:
+            model.bias.copy_(model_pt.bias)
+            model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
-                                        model.epsilon, rowscale=rowscale, layerscale=colscale,
-                                        residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
+    out, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
+                                     model.epsilon, rowscale=rowscale, layerscale=colscale,
+                                     residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
     assert out.dtype == input_dtype
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
     if has_residual:

@@ -101,7 +118,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
     if has_residual:
         assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
-    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
+    if not is_rms_norm:
+        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
     if has_colscale:
         assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4

@@ -151,27 +169,34 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
     assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4


-# @pytest.mark.parametrize('has_colscale', [True, False])
-# @pytest.mark.parametrize('has_rowscale', [True, False])
-# @pytest.mark.parametrize('has_residual', [True, False])
-# @pytest.mark.parametrize('dropout_p', [0.37, 0.0])
-# @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
-# @pytest.mark.parametrize('input_dtype,residual_dtype',
-#                          [(torch.float16, torch.float16), (torch.float16, torch.float32),
-#                           (torch.float32, torch.float32)]
-#                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
-@pytest.mark.parametrize('has_colscale', [True])
-@pytest.mark.parametrize('has_rowscale', [False])
-@pytest.mark.parametrize('has_residual', [True])
-@pytest.mark.parametrize('dropout_p', [0.0])
-@pytest.mark.parametrize('weight_dtype', [torch.float32])
-@pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
-# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
-@pytest.mark.parametrize('hidden_size', [256])
+@pytest.mark.parametrize('is_rms_norm', [False, True])
+@pytest.mark.parametrize('has_colscale', [True, False])
+@pytest.mark.parametrize('has_rowscale', [True, False])
+@pytest.mark.parametrize('has_residual', [True, False])
+@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+# @pytest.mark.parametrize('has_colscale', [True])
+# @pytest.mark.parametrize('has_rowscale', [False])
+# @pytest.mark.parametrize('has_residual', [True])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+# @pytest.mark.parametrize('weight_dtype', [torch.float32])
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [256])
 def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
-                                             dropout_p, has_residual, has_rowscale, has_colscale):
+                                             dropout_p, has_residual, has_rowscale, has_colscale,
+                                             is_rms_norm):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
         pytest.skip()  # Not supported
+    if is_rms_norm and FusedRMSNorm is None:
+        pytest.skip()  # We need Apex's FusedRMSNorm to test
+    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
+    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
+    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
     device = 'cuda'
     # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
     rtol, atol = (1e-3, 2e-4)

@@ -208,23 +233,25 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
     if has_colscale:
         x0_scaled_pt = x0_scaled_pt * colscale_pt
         x0_scaled_ref = x0_scaled_ref * colscale_ref
-    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
-    torch.nn.init.normal_(model_pt.bias)
-    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
-    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
-                                dtype=weight_dtype)
+    if not is_rms_norm:
+        torch.nn.init.normal_(model_pt.bias)
+    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
+    model = our_layer_norm_cls(hidden_size, prenorm=True, p=dropout_p, device=device,
+                               dtype=weight_dtype)
     with torch.no_grad():
         model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
         model_ref.weight.copy_(model_pt.weight)
-        model_ref.bias.copy_(model_pt.bias)
+        if not is_rms_norm:
+            model.bias.copy_(model_pt.bias)
+            model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
-                                                  model.epsilon, rowscale=rowscale,
-                                                  layerscale=colscale, prenorm=True,
-                                                  residual_in_fp32=residual_in_fp32,
-                                                  return_dropout_mask=True)
+    out, residual, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
+                                               model.epsilon, rowscale=rowscale,
+                                               layerscale=colscale, prenorm=True,
+                                               residual_in_fp32=residual_in_fp32,
+                                               return_dropout_mask=True)
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
     if has_residual:
         residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)

@@ -247,7 +274,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
     if has_residual:
         assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
-    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
+    if not is_rms_norm:
+        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
     if has_colscale:
         assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
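The assertions keep the existing pattern: the fused kernel output, a low-precision baseline (torch.nn.LayerNorm or Apex's FusedRMSNorm), and the same baseline in float32 as ground truth; the fused output only has to stay within a small multiple of the baseline's own rounding error. The idiom in isolation (the helper name and arguments here are ours, for illustration):

```python
import torch

def close_enough(out, out_pt, out_ref, mult=4, atol=1e-4):
    # out: fused kernel output; out_pt: baseline module in the same low precision;
    # out_ref: the same baseline run in float32, treated as ground truth.
    return (out - out_ref).abs().max() <= mult * (out_pt - out_ref).abs().max() + atol
```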