gaoqiong / flash-attention · Commits

Commit eb33e587, authored Jan 19, 2023 by Tri Dao

[LayerNorm] Rename x1 -> residual

parent f68d41ec
Showing 7 changed files with 89 additions and 88 deletions (+89 / -88)
csrc/layer_norm/ln.h                 +3   -3
csrc/layer_norm/ln_api.cpp           +15  -15
csrc/layer_norm/ln_bwd_kernels.cuh   +4   -4
csrc/layer_norm/ln_fwd_kernels.cuh   +5   -5
flash_attn/models/gpt.py             +2   -2
flash_attn/ops/layer_norm.py         +48  -48
flash_attn/ops/rms_norm.py           +12  -11
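The change is a mechanical rename: the second argument of the fused dropout + add + LayerNorm path, previously called x1, is now called residual everywhere (parameter structs, C++ API, CUDA kernels, and the Python wrappers). A minimal sketch of the user-visible difference at the Python level, assuming the dropout_layer_norm extension is built and a CUDA device is available; shapes and dtypes here are illustrative only:

import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm

hidden = torch.randn(2, 1024, 768, device='cuda', dtype=torch.float16)   # output of attn / MLP
residual = torch.randn_like(hidden)                                      # running residual stream
weight = torch.ones(768, device='cuda', dtype=torch.float16)
bias = torch.zeros(768, device='cuda', dtype=torch.float16)

# Before this commit the second positional argument was named x1; after it, residual.
out = dropout_add_layer_norm(hidden, residual, weight, bias, 0.1, 1e-5)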
csrc/layer_norm/ln.h

@@ -59,7 +59,7 @@ struct ParamsBase {
     // Common data pointers.
     void *x0;
-    void *x1;
+    void *residual;
     void *x;
     void *dmask;
     void *mu;

@@ -117,7 +117,7 @@ struct BwdParams : public ParamsBase {
         , dgamma_part(nullptr)
         , dcolscale_part(nullptr)
         , dx0(nullptr)
-        , dx1(nullptr)
+        , dresidual(nullptr)
         , dbeta(nullptr)
         , dgamma(nullptr)
         , dcolscale(nullptr)

@@ -136,7 +136,7 @@ struct BwdParams : public ParamsBase {
     // Output: Dgrad.
     void *dx0;
-    void *dx1;
+    void *dresidual;
     // Output: Wgrad.
     void *dbeta;
     void *dgamma;
csrc/layer_norm/ln_api.cpp

@@ -81,7 +81,7 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input: BxSxhidden_size
-                                           c10::optional<const at::Tensor> &x1_,      // Residual: BxSxhidden_size
+                                           c10::optional<const at::Tensor> &residual_,      // Residual: BxSxhidden_size
                                            const at::Tensor &gamma,   // hidden_size
                                            c10::optional<const at::Tensor> &beta_,   // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,      // BxS

@@ -97,8 +97,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
                                            bool is_rms_norm=false
 ) {
     auto itype = x0.scalar_type();
-    auto rtype = x1_.has_value()
-        ? x1_.value().scalar_type()
+    auto rtype = residual_.has_value()
+        ? residual_.value().scalar_type()
         : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
     auto wtype = gamma.scalar_type();
     auto otype = itype;

@@ -129,11 +129,11 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
         TORCH_CHECK(gamma.sizes() == beta.sizes());
     }

-    if (x1_.has_value()) {
-        auto x1 = x1_.value();
-        TORCH_CHECK(x1.is_cuda())
-        TORCH_CHECK(x1.is_contiguous());
-        TORCH_CHECK(x1.sizes() == sizes);
+    if (residual_.has_value()) {
+        auto residual = residual_.value();
+        TORCH_CHECK(residual.is_cuda())
+        TORCH_CHECK(residual.is_contiguous());
+        TORCH_CHECK(residual.sizes() == sizes);
     }

     if (rowscale_.has_value()) {

@@ -178,7 +178,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     auto opts = x0.options();

-    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
+    bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
     at::Tensor x;
     if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
     at::Tensor dmask;

@@ -194,7 +194,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(dropout_p < 1.f);
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
-    launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
+    launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
     launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
     launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;

@@ -383,8 +383,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
     auto opts = x.options();

     auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
-    at::Tensor dx1;
-    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
+    at::Tensor dresidual;
+    if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
     auto dgamma = torch::empty_like(gamma);
     auto dbeta = torch::empty_like(gamma);
     at::Tensor dcolscale;

@@ -397,7 +397,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
     launch_params.props = at::cuda::getCurrentDeviceProperties();
     TORCH_CHECK(dropout_p < 1.f);
     launch_params.params.dropout_keep_p = 1.f - dropout_p;
-    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
+    launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
     launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
     launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
     launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;

@@ -450,7 +450,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
     launcher(launch_params, false);

-    std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+    std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
     if (colscale_.has_value()) {
         result.push_back(dcolscale);
         result.push_back(dcolscale_part);

@@ -462,7 +462,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "CUDA DropoutAddLayerNorm";
     m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
-          py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"),
+          py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta"),
           py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
           py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
           py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
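For reference, the dtype rule that the renamed residual_ argument feeds into can be read directly from the rtype hunk above: the residual (and the saved pre-norm sum x) take the residual's dtype when one is passed, otherwise fp32 if residual_in_fp32 is set, otherwise the input dtype. A rough Python restatement of that rule; the function name here is descriptive and not part of the extension's API:

import torch

def residual_dtype(x0_dtype, residual=None, residual_in_fp32=False):
    # Mirrors the rtype computation in dropout_add_ln_fwd: prefer the residual's dtype,
    # else fp32 when residual_in_fp32 is requested, else fall back to the input dtype.
    if residual is not None:
        return residual.dtype
    return torch.float32 if residual_in_fp32 else x0_dtype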
csrc/layer_norm/ln_bwd_kernels.cuh

@@ -37,7 +37,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
     extern __shared__ char smem_[];

-    const bool has_residual = params.dx1 != nullptr;
+    const bool has_residual = params.dresidual != nullptr;
     const bool prenorm = params.dx != nullptr;

     const index_t tidx = threadIdx.x;

@@ -164,7 +164,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
     for (int it = 0; it < LDGS; it++) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             Ivec dx0;
-            Rvec dx1;
+            Rvec dresidual;
             Ivec x0;
             if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
             #pragma unroll

@@ -178,7 +178,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                 } else {
                     dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
                 }
-                if (has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
+                if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
                 if (save_dx0) {
                     compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
                     if (Is_dropout) {

@@ -199,7 +199,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                     }
                 }
             }
-            if (has_residual) { dx1.store_to(params.dx1, idx_x); }
+            if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
             if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
             idx_x += Ktraits::VEC_COLS_PER_LDG;
             idx_x0 += Ktraits::VEC_COLS_PER_LDG;
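In the backward kernel the rename is again one-for-one: the register vector and the output pointer that receive the gradient of the pre-norm sum are now called dresidual. Since the forward computes x = dropout(x0) + residual, the gradient that flows to the residual input is exactly the gradient of x, which is the dx_tmp_res value the loop above stores into dresidual. A tiny unfused autograd check of that identity, in plain PyTorch rather than the fused kernel:

import torch
import torch.nn.functional as F

x0 = torch.randn(4, 8, requires_grad=True)
residual = torch.randn(4, 8, requires_grad=True)
weight = torch.ones(8, requires_grad=True)

x = F.dropout(x0, p=0.0) + residual          # pre-norm sum (dropout disabled for clarity)
x.retain_grad()
z = F.layer_norm(x, (8,), weight=weight)
z.sum().backward()

# The gradient reaching the residual input equals the gradient of the pre-norm sum,
# which is what the fused backward writes into params.dresidual.
assert torch.allclose(residual.grad, x.grad)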
csrc/layer_norm/ln_fwd_kernels.cuh

@@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
     using Stats = typename Ktraits::Stats;
     using stats_t = typename Stats::stats_t;

-    const bool has_residual = params.x1 != nullptr;
+    const bool has_residual = params.residual != nullptr;
     const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);

     extern __shared__ char smem_[];

@@ -111,11 +111,11 @@ void ln_fwd_kernel(FwdParams params) {
     for (int it = 0; it < LDGS; it++) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             Ivec x0;
-            Rvec x1;
+            Rvec residual;
             Rvec x;
             Mvec dmask;
             if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
-            if (has_residual) { x1.load_from(params.x1, idx_x); }
+            if (has_residual) { residual.load_from(params.residual, idx_x); }
             #pragma unroll
             for (int jt = 0; jt < NUM_ELTS; jt++) {
                 // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use

@@ -127,9 +127,9 @@ void ln_fwd_kernel(FwdParams params) {
                     compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
                     x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
                     if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
-                    x_ij = has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
+                    x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
                 } else {
-                    x_ij = has_residual ? compute_t(x1.data.elt[jt]) : 0.f;
+                    x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
                 }
                 if (save_x) { x.data.elt[jt] = x_ij; }
                 xf[it * NUM_ELTS + jt] = x_ij;
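The renamed residual vector is simply added to the dropped (and optionally row/col-scaled) input before the normalization statistics are computed. As a rough reference for what the fused forward produces, here is an unfused sketch under the simplest settings (no rowscale, colscale, or subset); the fused kernel additionally decides via save_x whether to materialize the pre-norm sum x, which is hinted at in the returned tuple below:

import torch
import torch.nn.functional as F

def dropout_add_layer_norm_ref(x0, residual, weight, bias, dropout_p, epsilon):
    # Unfused sketch of the fused kernel's math in its simplest configuration.
    x = F.dropout(x0, p=dropout_p)
    if residual is not None:
        x = x + residual                      # pre-norm sum
    z = F.layer_norm(x, (x.shape[-1],), weight=weight, bias=bias, eps=epsilon)
    return z, x                               # the fused op also exposes x for prenorm use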
flash_attn/models/gpt.py

@@ -292,7 +292,7 @@ class GPTModel(GPTPreTrainedModel):
             residual = (dropped + residual) if residual is not None else dropped
             hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
         else:
-            # Set prenorm=False here since we don't need to the residual
+            # Set prenorm=False here since we don't need the residual
             hidden_states = dropout_add_layer_norm(
                 hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                 self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,

@@ -359,7 +359,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         # Previous: Attn / MLP -> Dropout -> Add -> LN
         # Current: Dropout -> Add -> LN -> Attn / MLP
         if 'transformer.ln_0.weight' in state_dict:
-            n_layers = self.config.num_hidden_layers
+            n_layers = len(self.transformer.layers)
             ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
             ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
             state_dict['transformer.ln_f.weight'] = ln_weight
flash_attn/ops/layer_norm.py

@@ -7,20 +7,20 @@ from torch.nn import init
 import dropout_layer_norm


-def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32=False, is_rms_norm=False):
+def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))
-    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
+        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
         1.0, 0, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
-    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma

@@ -28,7 +28,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
                                      dropout_p, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
-    (x = drop(x0) + x1 was not returned in the fwd).
+    (x = drop(x0) + residual was not returned in the fwd).
     x0 must not be None if we have colscale.
     """
     hidden_size = gamma.numel()

@@ -39,34 +39,34 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
     rowscale = rowscale.view(-1) if rowscale is not None else None
     if colscale is not None:
         assert x0 is not None, 'x0 is required to compute the gradient of colscale'
-    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
         dropout_p, 1.0, 0, has_residual, is_rms_norm
     )
-    # dx1mat is None if not has_residual
+    # dresidualmat is None if not has_residual
     if colscale is None:
-        return dx0mat, dx1mat, dgamma, dbeta
+        return dx0mat, dresidualmat, dgamma, dbeta
     else:
         dcolscale = rest[0]
-        return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


-def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32=False, is_rms_norm=False):
+def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))
-    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
     x0_subset = x0_subset.view(-1) if x0_subset is not None else None
     out_subset = out_subset.view(-1) if out_subset is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-        x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
+        x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
         rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
-    # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
     return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma

@@ -75,7 +75,7 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
                                             x0_numrows, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
-    (x = drop(x0) + x1 was not returned in the fwd).
+    (x = drop(x0) + residual was not returned in the fwd).
     x0 must not be None if we have colscale.
     """
     hidden_size = gamma.numel()

@@ -87,30 +87,30 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
     out_subset = out_subset.view(-1) if out_subset is not None else None
     if colscale is not None:
         assert x0 is not None, 'x0 is required to compute the gradient of colscale'
-    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
         dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
     )
-    # dx1mat is None if not has_residual
+    # dresidualmat is None if not has_residual
     if colscale is None:
-        return dx0mat, dx1mat, dgamma, dbeta
+        return dx0mat, dresidualmat, dgamma, dbeta
     else:
         dcolscale = rest[0]
-        return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


 class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+    def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                 residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
-        x1 = x1.contiguous() if x1 is not None else None
+        residual = residual.contiguous() if residual is not None else None
         gamma = gamma.contiguous()
         beta = beta.contiguous() if beta is not None else None
         rowscale = rowscale.contiguous() if rowscale is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+            x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
             residual_in_fp32, is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale

@@ -118,7 +118,7 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
             ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
-        ctx.has_residual = x1 is not None
+        ctx.has_residual = residual is not None
         ctx.is_rms_norm = is_rms_norm
         ctx.has_beta = beta is not None
         if not return_dmask:

@@ -140,29 +140,29 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
-        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
+        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
             dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
             ctx.is_rms_norm
         )
         dx0 = dx0mat.view(x.shape)
-        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None, None, None, None, None)
+        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None, None, None, None, None)


 class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+    def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
                 rowscale_const, out_numrows, residual_in_fp32=False, prenorm=False,
                 is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
-        x1 = x1.contiguous() if x1 is not None else None
+        residual = residual.contiguous() if residual is not None else None
         gamma = gamma.contiguous()
         beta = beta.contiguous() if beta is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
-            x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+            x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
             rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale

@@ -174,7 +174,7 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         ctx.dropout_p = dropout_p
         ctx.rowscale_const = rowscale_const
         ctx.x0_numrows = x0.shape[:-1].numel()
-        ctx.has_residual = x1 is not None
+        ctx.has_residual = residual is not None
         ctx.is_rms_norm = is_rms_norm
         ctx.has_beta = beta is not None
         z_shape = (-1, *x0.shape[1:])

@@ -197,42 +197,42 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
-        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
+        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
             dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
             ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
         )
         dx0 = dx0mat.view(-1, *x.shape[1:])
-        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None, None, None, None, None, None, None, None)
+        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None, None, None, None, None, None, None, None)


 def layer_norm(x, weight, bias, epsilon):
     return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)


-def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
-                           layerscale=None, prenorm=False, residual_in_fp32=False,
+def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
+                           layerscale=None, prenorm=False, residual_in_fp32=False,
                            return_dropout_mask=False):
-    """residual_in_fp32 only has an effect if x1 is None.
-    Otherwise residual dtype is x1.dtype.
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
     """
     return DropoutAddLayerNormFn.apply(
-        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
+        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
         False, return_dropout_mask
     )


-def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
+def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                   x0_subset=None, out_subset=None, rowscale_const=1.0,
                                   out_numrows=0, prenorm=False, residual_in_fp32=False,
                                   return_dropout_mask=False):
-    """residual_in_fp32 only has an effect if x1 is None.
-    Otherwise residual dtype is x1.dtype.
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
     """
     return DropoutAddLayerNormSubsetFn.apply(
-        x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
+        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
         rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
     )

@@ -254,7 +254,7 @@ class DropoutAddLayerNorm(torch.nn.Module):
         init.ones_(self.weight)
         init.zeros_(self.bias)

-    def forward(self, x0, x1=None):
-        return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
+    def forward(self, x0, residual=None):
+        return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
                                       self.p if self.training else 0.0, self.epsilon,
                                       prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
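Downstream code such as flash_attn/models/gpt.py now passes the running residual stream under the new name. A hedged usage sketch of the prenorm path, assuming the CUDA extension is installed; shapes, dtypes, and hyperparameters are illustrative:

import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm

hidden_states = torch.randn(2, 512, 1024, device='cuda', dtype=torch.float16)
residual = None   # no residual stream yet before the first block
weight = torch.ones(1024, device='cuda', dtype=torch.float16)
bias = torch.zeros(1024, device='cuda', dtype=torch.float16)

# With prenorm=True the call returns both the normalized output and the updated
# pre-norm sum, so the next block can keep accumulating into `residual`.
normed, residual = dropout_add_layer_norm(
    hidden_states, residual, weight, bias, 0.1, 1e-5,
    prenorm=True, residual_in_fp32=True
)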
flash_attn/ops/rms_norm.py

@@ -12,26 +12,27 @@ def rms_norm(x, weight, epsilon):
                                          False, True)


-def dropout_add_rms_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
-                         layerscale=None, prenorm=False, residual_in_fp32=False,
-                         return_dropout_mask=False):
-    """residual_in_fp32 only has an effect if x1 is None.
-    Otherwise residual dtype is x1.dtype.
+def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
+                         layerscale=None, prenorm=False, residual_in_fp32=False,
+                         return_dropout_mask=False):
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
     """
     return DropoutAddLayerNormFn.apply(
-        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
+        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
         True, return_dropout_mask
     )


-def dropout_add_rms_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
+def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                 x0_subset=None, out_subset=None, rowscale_const=1.0,
                                 out_numrows=0, prenorm=False, residual_in_fp32=False,
                                 return_dropout_mask=False):
-    """residual_in_fp32 only has an effect if x1 is None.
-    Otherwise residual dtype is x1.dtype.
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
     """
     return DropoutAddLayerNormSubsetFn.apply(
-        x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
+        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
         rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
     )

@@ -52,7 +53,7 @@ class DropoutAddRMSNorm(torch.nn.Module):
     def reset_parameters(self):
         init.ones_(self.weight)

-    def forward(self, x0, x1=None):
-        return dropout_add_rms_norm(x0, x1, self.weight, None,
+    def forward(self, x0, residual=None):
+        return dropout_add_rms_norm(x0, residual, self.weight, None,
                                     self.p if self.training else 0.0, self.epsilon,
                                     prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
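The RMSNorm wrappers route through the same DropoutAddLayerNormFn/DropoutAddLayerNormSubsetFn autograd functions with the is_rms_norm flag set, so the rename is identical there. For orientation only, a generic unfused sketch of the RMS variant's math (dropout, add the residual, then RMS-normalize without mean subtraction or bias); this is a standard RMSNorm formulation and is not claimed to be bit-exact to the fused kernel:

import torch
import torch.nn.functional as F

def dropout_add_rms_norm_ref(x0, residual, weight, dropout_p, epsilon):
    # Unfused sketch: dropout, residual add, then scale by the reciprocal RMS.
    x = F.dropout(x0, p=dropout_p)
    if residual is not None:
        x = x + residual
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    return (x / rms) * weight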