gaoqiong / flash-attention · Commit 6738d947
Authored Jan 06, 2023 by Tri Dao
[LayerNorm] Implement RMS Norm
Parent: a1f49a2b

Showing 9 changed files with 205 additions and 80 deletions:
  csrc/layer_norm/README.md              +1   -0
  csrc/layer_norm/ln.h                   +3   -0
  csrc/layer_norm/ln_api.cpp             +26  -10
  csrc/layer_norm/ln_bwd_kernels.cuh     +2   -2
  csrc/layer_norm/ln_fwd_kernels.cuh     +7   -3
  csrc/layer_norm/ln_utils.cuh           +7   -0
  flash_attn/ops/layer_norm.py           +30  -22
  flash_attn/ops/rms_norm.py             +58  -0
  tests/ops/test_dropout_layer_norm.py   +71  -43
csrc/layer_norm/README.md

...
@@ -2,6 +2,7 @@ This CUDA extension implements fused dropout + residual + LayerNorm, building on
 Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
 We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
 We also make it work for more hidden dimensions (all dimensions divisible by 8, up to 6144).
+We also implement RMSNorm as an option.
 If you want to use it for dimensions larger than 6k, please file an issue.
...
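For orientation, the two normalizations the extension now supports differ only in whether the per-row mean is subtracted and a bias is added (standard definitions, stated here rather than quoted from the repo; d is the hidden dimension):

\[
\mathrm{LN}(x)_i = \gamma_i\,\frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i,
\qquad \mu = \frac{1}{d}\sum_{j=1}^{d} x_j,\quad
\sigma^2 = \frac{1}{d}\sum_{j=1}^{d}(x_j - \mu)^2,
\]
\[
\mathrm{RMS}(x)_i = \gamma_i\,\frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}} .
\]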
csrc/layer_norm/ln.h

...
@@ -44,6 +44,7 @@ struct ParamsBase {
         , colscale(nullptr)
         , dropout_keep_p(1.f)
         , dropout_scale(1.f)
+        , is_rms_norm(false)
         , workspace(nullptr)
         , barrier(nullptr)
     {
...
@@ -75,6 +76,8 @@ struct ParamsBase {
     float dropout_scale;
     float rowscale_const;
+    bool is_rms_norm;

     // Multi-CTA workspace in gmem.
     void *workspace;
...
csrc/layer_norm/ln_api.cpp

...
@@ -83,7 +83,7 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
 std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input: BxSxhidden_size
                                            c10::optional<const at::Tensor> &x1_,      // Residual: BxSxhidden_size
                                            const at::Tensor &gamma,   // hidden_size
-                                           const at::Tensor &beta,   // hidden_size
+                                           c10::optional<const at::Tensor> &beta_,   // hidden_size
                                            c10::optional<const at::Tensor> &rowscale_,      // BxS
                                            c10::optional<const at::Tensor> &colscale_,      // hidden_size
                                            c10::optional<const at::Tensor> &x0_subset_,      // BxS
...
@@ -93,7 +93,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
                                            const float rowscale_const,
                                            const int64_t z_numrows,
                                            c10::optional<at::Generator> gen_,
-                                           bool residual_in_fp32) {
+                                           bool residual_in_fp32=false,
+                                           bool is_rms_norm=false) {
     auto itype = x0.scalar_type();
     auto rtype = x1_.has_value()
...
@@ -104,11 +105,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     auto ctype = torch::kFloat32;
     auto mtype = torch::kUInt8;
-    TORCH_CHECK(beta.dtype() == wtype);
     TORCH_CHECK(x0.is_cuda())
     TORCH_CHECK(gamma.is_cuda())
-    TORCH_CHECK(beta.is_cuda())
     TORCH_CHECK(x0.is_contiguous());
     // c10::IntArrayRef does not own the storage, so we need to construct a vector.
...
@@ -123,6 +121,14 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     const int cols = sizes[1];
     auto hidden_size = gamma.numel();

+    if (beta_.has_value()) {
+        auto beta = beta_.value();
+        TORCH_CHECK(beta.dtype() == wtype);
+        TORCH_CHECK(beta.is_cuda())
+        TORCH_CHECK(beta.is_contiguous());
+        TORCH_CHECK(gamma.sizes() == beta.sizes());
+    }
+
     if (x1_.has_value()) {
         auto x1 = x1_.value();
         TORCH_CHECK(x1.is_cuda())
...
@@ -161,7 +167,6 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
         TORCH_CHECK(z_subset.dtype() == torch::kInt32);
     }

-    TORCH_CHECK(gamma.sizes() == beta.sizes());
     TORCH_CHECK(hidden_size == cols);
     TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
...
@@ -218,12 +223,13 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     params.mu = mu.data_ptr();
     params.rs = rsigma.data_ptr();
     params.gamma = gamma.data_ptr();
-    params.beta = beta.data_ptr();
+    params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr;
     params.z = z.data_ptr();
     params.epsilon = epsilon;
     params.dropout_scale = 1.f / (1.f - dropout_p);
     params.inverse_cols = 1.f / float(params.cols);
     params.rowscale_const = rowscale_const;
+    params.is_rms_norm = is_rms_norm;

     if (dropout_p > 0.f) {
         // number of times random will be generated per thread, to offset philox counter in thc random
...
@@ -268,7 +274,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
                                            const float dropout_p,
                                            const float rowscale_const,
                                            const int64_t x0_numrows,
-                                           const bool has_residual) {
+                                           const bool has_residual,
+                                           bool is_rms_norm=false) {
     auto itype = dz.scalar_type();
...
@@ -431,6 +438,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
     params.dropout_scale = 1.f / (1.f - dropout_p);
     params.inverse_cols = 1.f / float(params.cols);
     params.rowscale_const = rowscale_const;
+    params.is_rms_norm = is_rms_norm;

     if (launch_params.barrier_size > 0) {
         // TODO Any way to avoid this?
...
@@ -453,6 +461,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "CUDA DropoutAddLayerNorm";
-    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
-    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
+    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
+          py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"),
+          py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
+          py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
+          py::arg("gen_"), py::arg("residual_in_fp32") = false, py::arg("is_rms_norm") = false);
+    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
+          py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"),
+          py::arg("mu"), py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
+          py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
+          py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm") = false);
 }
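Because every argument now carries a py::arg name and the two new flags default to false, existing call sites keep working and new ones can opt into RMSNorm by keyword. A hedged sketch of calling the built extension directly (the wrappers in flash_attn/ops below are the intended entry point; the tensors here are illustrative):

import torch
import dropout_layer_norm  # the extension compiled from csrc/layer_norm

x0 = torch.randn(8 * 512, 1024, device='cuda', dtype=torch.float16)  # (rows, hidden_size)
gamma = torch.ones(1024, device='cuda', dtype=torch.float16)

# x1_ (residual) and beta_ (bias) are c10::optional now, so None is accepted;
# residual_in_fp32 is left at its new default (false).
z, x, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
    x0, None, gamma, None,        # x0, x1_, gamma, beta_
    None, None, None, None,       # rowscale_, colscale_, x0_subset_, z_subset_
    0.0, 1e-5, 1.0, 0, None,      # dropout_p, epsilon, rowscale_const, z_numrows, gen_
    is_rms_norm=True)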
csrc/layer_norm/ln_bwd_kernels.cuh

...
@@ -125,7 +125,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
             #pragma unroll
             for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                 compute_t x_tmp = x.data.elt[jt];
-                compute_t y_tmp = rs_r * (x_tmp - mu_r);
+                compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));
                 compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
                 compute_t dz_tmp = dz.data.elt[jt];
...
@@ -173,7 +173,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
                 if (load_dz) {
                     compute_t dy_tmp = dy[it * NUM_ELTS + jt];
                     compute_t y_tmp = y[it * NUM_ELTS + jt];
-                    compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
+                    compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));
                     dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
                 } else {
                     dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
...
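The two changed lines follow the standard input-gradient formulas for the two norms. Writing dy_i = gamma_i * dz_i, per-row means with an overbar, and rs for the saved reciprocal (root-mean-square) standard deviation, a sketch of what the kernel computes (up to its exact reduction order) is:

\[
dx^{\mathrm{LN}}_i = rs\left(dy_i - \overline{dy\,y}\;y_i - \overline{dy}\right),
\qquad
dx^{\mathrm{RMS}}_i = rs\left(dy_i - \overline{dy\,y}\;y_i\right),
\qquad \overline{v} \equiv \frac{1}{d}\sum_{j=1}^{d} v_j .
\]

Here mdy_local and mdyy_local are the kernel's per-row reductions of dy and dy*y; the RMS case simply zeroes the \(\overline{dy}\) term, because there is no mean to differentiate through.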
csrc/layer_norm/ln_fwd_kernels.cuh

...
@@ -89,7 +89,11 @@ void ln_fwd_kernel(FwdParams params) {
     for( int it = 0; it < LDGS; it++ ) {
         if (Is_even_cols || (it < num_valid_ldgs)) {
             gamma[it].load_from(params.gamma, idx);
-            beta[it].load_from(params.beta, idx);
+            if (params.beta != nullptr) {
+                beta[it].load_from(params.beta, idx);
+            } else {
+                beta[it].zero_();
+            }
             if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
             idx += VEC_COLS_PER_LDG;
         }
...
@@ -159,7 +163,7 @@ void ln_fwd_kernel(FwdParams params) {
         mu_ptr[row] = mu;
     }

-    compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon);
+    compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));

     if( bidn == 0 && warp_n == 0 && lane == 0 ) {
         rs_ptr[row] = rs;
...
@@ -174,7 +178,7 @@ void ln_fwd_kernel(FwdParams params) {
         Ovec z;
         #pragma unroll
         for( int jt = 0; jt < NUM_ELTS; jt++ ) {
-            compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu));
+            compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
             compute_t g_ij = gamma[it].data.elt[jt];
             compute_t b_ij = beta[it].data.elt[jt];
             z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
...
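The forward kernel already produces the per-row mean mu and the sum of squared deviations m2 in its single statistics pass, so the RMS statistic comes for free from the identity mean(x^2) = m2/d + mu^2; that is exactly what the extra `mu * mu` term inside `rsqrtf` does:

\[
rs_{\mathrm{LN}} = \frac{1}{\sqrt{m_2/d + \epsilon}},
\qquad
rs_{\mathrm{RMS}} = \frac{1}{\sqrt{m_2/d + \mu^2 + \epsilon}}
= \frac{1}{\sqrt{\tfrac{1}{d}\sum_j x_j^2 + \epsilon}} ,
\]

and the normalized value is \(y = rs\,(x - \mu)\) for LayerNorm versus \(y = rs\,x\) for RMSNorm, which is what the last changed line switches on.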
csrc/layer_norm/ln_utils.cuh

...
@@ -308,6 +308,13 @@ struct Vec {
         }
     }

+    inline __device__ void zero_() {
+        #pragma unroll
+        for( int it = 0; it < NUM_ELT; it++ ) {
+            this->data.elt[it] = Elt_type(0.f);
+        }
+    }
+
     inline __device__ void load_from(const void *base_ptr, const size_t idx) {
         this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
     }
...
flash_attn/ops/layer_norm.py

...
@@ -8,7 +8,7 @@ import dropout_layer_norm
 def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
-                                    residual_in_fp32):
+                                    residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()
...
@@ -17,7 +17,7 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro
     rowscale = rowscale.view(-1) if rowscale is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
         x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
-        1.0, 0, None, residual_in_fp32
+        1.0, 0, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
...
@@ -25,7 +25,7 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro
 def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
-                                     dropout_p, has_residual):
+                                     dropout_p, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + x1 was not returned in the fwd).
...
@@ -41,7 +41,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
     assert x0 is not None, 'x0 is required to compute the gradient of colscale'
     dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
-        dropout_p, 1.0, 0, has_residual
+        dropout_p, 1.0, 0, has_residual, is_rms_norm
     )
     # dx1mat is None if not has_residual
     if colscale is None:
...
@@ -53,7 +53,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
 def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
                                            dropout_p, epsilon, rowscale_const, out_numrows,
-                                           residual_in_fp32):
+                                           residual_in_fp32=False, is_rms_norm=False):
     """ Assume that arguments are contiguous
     """
     hidden_size = gamma.numel()
...
@@ -63,7 +63,7 @@ def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_sub
     out_subset = out_subset.view(-1) if out_subset is not None else None
     zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
         x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
-        rowscale_const, out_numrows, None, residual_in_fp32
+        rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
     )
     # dmask is None if dropout_p == 0.0
     # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
...
@@ -72,7 +72,7 @@ def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_sub
 def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                             x0_subset, out_subset, dropout_p, rowscale_const,
-                                            x0_numrows, has_residual):
+                                            x0_numrows, has_residual, is_rms_norm=False):
     """ Assume that arguments are contiguous
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + x1 was not returned in the fwd).
...
@@ -89,7 +89,7 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
     assert x0 is not None, 'x0 is required to compute the gradient of colscale'
     dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
         dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
-        dropout_p, rowscale_const, x0_numrows, has_residual
+        dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
     )
     # dx1mat is None if not has_residual
     if colscale is None:
...
@@ -101,16 +101,17 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
 class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
-                prenorm=False, return_dmask=False):
+    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
         x1 = x1.contiguous() if x1 is not None else None
         gamma = gamma.contiguous()
-        beta = beta.contiguous()
+        beta = beta.contiguous() if beta is not None else None
         rowscale = rowscale.contiguous() if rowscale is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
+            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
+            is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None
...
@@ -118,6 +119,8 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
         ctx.has_residual = x1 is not None
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_beta = beta is not None
         if not return_dmask:
             return (zmat.view(x0.shape) if not prenorm
                     else (zmat.view(x0.shape), xmat.view(x0.shape)))
...
@@ -138,26 +141,29 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         dropout_p = ctx.dropout_p
         has_residual = ctx.has_residual
         dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
-            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual
+            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
+            ctx.is_rms_norm
         )
         dx0 = dx0mat.view(x.shape)
         dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None
+        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None,
+                None, None, None, None)


 class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
-                rowscale_const, out_numrows, residual_in_fp32, prenorm=False, return_dmask=False):
+    def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+                rowscale_const, out_numrows, residual_in_fp32=False, prenorm=False,
+                is_rms_norm=False, return_dmask=False):
         x0 = x0.contiguous()
         x1 = x1.contiguous() if x1 is not None else None
         gamma = gamma.contiguous()
-        beta = beta.contiguous()
+        beta = beta.contiguous() if beta is not None else None
         colscale = colscale.contiguous() if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
             x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
-            rowscale_const, out_numrows, residual_in_fp32
+            rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None
...
@@ -169,6 +175,8 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         ctx.rowscale_const = rowscale_const
         ctx.x0_numrows = x0.shape[:-1].numel()
         ctx.has_residual = x1 is not None
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_beta = beta is not None
         z_shape = (-1, *x0.shape[1:])
         if not return_dmask:
             return (zmat.view(z_shape) if not prenorm
...
@@ -191,13 +199,13 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         has_residual = ctx.has_residual
         dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
             dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
-            ctx.rowscale_const, ctx.x0_numrows, has_residual
+            ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
         )
         dx0 = dx0mat.view(-1, *x.shape[1:])
         dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
         dcolscale = rest[0] if colscale is not None else None
-        return (dx0, dx1, dgamma, dbeta, dcolscale, None, None, None, None, None, None, None,
-                None, None)
+        return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None,
+                None, None, None, None, None, None, None)


 def layer_norm(x, weight, bias, epsilon):
...
@@ -212,7 +220,7 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No
     """
     return DropoutAddLayerNormFn.apply(
         x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
-        return_dropout_mask
+        False, return_dropout_mask
     )
...
@@ -225,7 +233,7 @@ def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, laye
     """
     return DropoutAddLayerNormSubsetFn.apply(
         x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
-        rowscale_const, out_numrows, residual_in_fp32, prenorm, return_dropout_mask
+        rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
     )
...
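For readers without the CUDA extension at hand, a rough pure-PyTorch sketch of what the forward path above computes (ignoring rowscale/colscale, the subset variants, and the returned dropout mask; the function name is ours, not part of the library):

import torch
import torch.nn.functional as F

def dropout_add_norm_ref(x0, x1, weight, bias, dropout_p, epsilon, is_rms_norm=False):
    # dropout(x0) + residual, accumulated in fp32 like the kernel's compute type
    x = F.dropout(x0.float(), dropout_p, training=True)
    if x1 is not None:
        x = x + x1.float()
    if is_rms_norm:
        # no mean subtraction, no bias
        rs = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + epsilon)
        y = x * rs
    else:
        mu = x.mean(dim=-1, keepdim=True)
        rs = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + epsilon)
        y = (x - mu) * rs
    out = y * weight.float()
    if bias is not None:
        out = out + bias.float()
    return out.to(x0.dtype)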
flash_attn/ops/rms_norm.py (new file, mode 100644)

# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init

from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn


def rms_norm(x, weight, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
                                       False, True)


def dropout_add_rms_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
                         prenorm=False, residual_in_fp32=False, return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
        True, return_dropout_mask
    )


def dropout_add_rms_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
                                x0_subset=None, out_subset=None, rowscale_const=1.0,
                                out_numrows=0, prenorm=False, residual_in_fp32=False,
                                return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
    )


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, x1=None):
        return dropout_add_rms_norm(x0, x1, self.weight, None, self.p if self.training else 0.0,
                                    self.epsilon, prenorm=self.prenorm,
                                    residual_in_fp32=self.residual_in_fp32)
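A minimal usage sketch of the new module and functional APIs (assumes a CUDA device and that the dropout_layer_norm extension is built; shapes are illustrative):

import torch
from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm

hidden_size = 1024
x0 = torch.randn(4, 512, hidden_size, device='cuda', dtype=torch.float16)
x1 = torch.randn_like(x0)  # residual stream

# Module interface: fused dropout + residual add + RMSNorm, pre-norm style
norm = DropoutAddRMSNorm(hidden_size, prenorm=True, p=0.1, eps=1e-5,
                         device='cuda', dtype=torch.float16)
out, residual = norm(x0, x1)   # prenorm=True also returns dropout(x0) + x1

# Functional interface, post-norm, no dropout
out_post = dropout_add_rms_norm(x0, x1, norm.weight, None, 0.0, 1e-5)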
tests/ops/test_dropout_layer_norm.py

...
@@ -8,11 +8,20 @@ from einops import rearrange, repeat
 from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
 from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
+from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm
+from flash_attn.ops.rms_norm import dropout_add_rms_norm_subset
+
+try:
+    from apex.normalization import FusedRMSNorm
+except:
+    FusedRMSNorm = None


 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


+@pytest.mark.parametrize('is_rms_norm', [False, True])
 @pytest.mark.parametrize('has_colscale', [True, False])
 # @pytest.mark.parametrize('has_colscale', [False])
 @pytest.mark.parametrize('has_rowscale', [True, False])
 # @pytest.mark.parametrize('has_rowscale', [True])
 @pytest.mark.parametrize('has_residual', [True, False])
...
@@ -26,11 +35,17 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
                           (torch.float32, torch.float32)]
                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)]
                             if is_sm8x else []))
 # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
-@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+@pytest.mark.parametrize('hidden_size', [256])
 def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
-                                     dropout_p, has_residual, has_rowscale, has_colscale):
+                                     dropout_p, has_residual, has_rowscale, has_colscale, is_rms_norm):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
         pytest.skip()  # Not supported
+    if is_rms_norm and FusedRMSNorm is None:
+        pytest.skip()  # We need Apex's FusedRMSNorm to test
+    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
+    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
+    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
     device = 'cuda'
     # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
     rtol, atol = (1e-3, 1e-4)
...
@@ -67,20 +82,22 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
     if has_colscale:
         x0_scaled_pt = x0_scaled_pt * colscale_pt
         x0_scaled_ref = x0_scaled_ref * colscale_ref
-    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
-    torch.nn.init.normal_(model_pt.bias)
-    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
-    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
+    if not is_rms_norm:
+        torch.nn.init.normal_(model_pt.bias)
+    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
+    model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
     with torch.no_grad():
         model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
         model_ref.weight.copy_(model_pt.weight)
-        model_ref.bias.copy_(model_pt.bias)
+        if not is_rms_norm:
+            model.bias.copy_(model_pt.bias)
+            model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p, model.epsilon,
-                                        rowscale=rowscale, layerscale=colscale,
-                                        residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
+    out, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p, model.epsilon,
+                                     rowscale=rowscale, layerscale=colscale,
+                                     residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
     assert out.dtype == input_dtype
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
     if has_residual:
...
@@ -101,7 +118,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
     if has_residual:
         assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
-    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
+    if not is_rms_norm:
+        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
     if has_colscale:
         assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
...
@@ -151,27 +169,34 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
     assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4


-# @pytest.mark.parametrize('has_colscale', [True, False])
-# @pytest.mark.parametrize('has_rowscale', [True, False])
-# @pytest.mark.parametrize('has_residual', [True, False])
-# @pytest.mark.parametrize('dropout_p', [0.37, 0.0])
-# @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
-# @pytest.mark.parametrize('input_dtype,residual_dtype',
-#                          [(torch.float16, torch.float16), (torch.float16, torch.float32),
-#                           (torch.float32, torch.float32)]
-#                          + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
-@pytest.mark.parametrize('has_colscale', [True])
-@pytest.mark.parametrize('has_rowscale', [False])
-@pytest.mark.parametrize('has_residual', [True])
-@pytest.mark.parametrize('dropout_p', [0.0])
-@pytest.mark.parametrize('weight_dtype', [torch.float32])
-@pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
-# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
-@pytest.mark.parametrize('hidden_size', [256])
+@pytest.mark.parametrize('is_rms_norm', [False, True])
+@pytest.mark.parametrize('has_colscale', [True, False])
+@pytest.mark.parametrize('has_rowscale', [True, False])
+@pytest.mark.parametrize('has_residual', [True, False])
+@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
+@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
+@pytest.mark.parametrize('input_dtype,residual_dtype',
+                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
+                          (torch.float32, torch.float32)]
+                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
+# @pytest.mark.parametrize('has_colscale', [True])
+# @pytest.mark.parametrize('has_rowscale', [False])
+# @pytest.mark.parametrize('has_residual', [True])
+# @pytest.mark.parametrize('dropout_p', [0.0])
+# @pytest.mark.parametrize('weight_dtype', [torch.float32])
+# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
+@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
+# @pytest.mark.parametrize('hidden_size', [256])
 def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
-                                             dropout_p, has_residual, has_rowscale, has_colscale):
+                                             dropout_p, has_residual, has_rowscale, has_colscale,
+                                             is_rms_norm):
     if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
         pytest.skip()  # Not supported
+    if is_rms_norm and FusedRMSNorm is None:
+        pytest.skip()  # We need Apex's FusedRMSNorm to test
+    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
+    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
+    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
     device = 'cuda'
     # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
     rtol, atol = (1e-3, 2e-4)
...
@@ -208,23 +233,25 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
     if has_colscale:
         x0_scaled_pt = x0_scaled_pt * colscale_pt
         x0_scaled_ref = x0_scaled_ref * colscale_ref
-    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
+    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
     torch.nn.init.normal_(model_pt.weight)
-    torch.nn.init.normal_(model_pt.bias)
-    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
-    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype)
+    if not is_rms_norm:
+        torch.nn.init.normal_(model_pt.bias)
+    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
+    model = our_layer_norm_cls(hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype)
     with torch.no_grad():
         model.weight.copy_(model_pt.weight)
-        model.bias.copy_(model_pt.bias)
         model_ref.weight.copy_(model_pt.weight)
-        model_ref.bias.copy_(model_pt.bias)
+        if not is_rms_norm:
+            model.bias.copy_(model_pt.bias)
+            model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
-    out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
-                                                  model.epsilon, rowscale=rowscale, layerscale=colscale,
-                                                  prenorm=True, residual_in_fp32=residual_in_fp32,
-                                                  return_dropout_mask=True)
+    out, residual, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
                                               model.epsilon, rowscale=rowscale, layerscale=colscale,
+                                               prenorm=True, residual_in_fp32=residual_in_fp32,
+                                               return_dropout_mask=True)
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
     if has_residual:
         residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
...
@@ -247,7 +274,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
     if has_residual:
         assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
     assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
-    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
+    if not is_rms_norm:
+        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
     if has_colscale:
         assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
...
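The new RMSNorm cases reuse the existing test module: they require a CUDA GPU with the dropout_layer_norm extension built, and are skipped unless Apex's FusedRMSNorm is importable (per the try/except at the top of the file). A typical invocation is simply:

    pytest -q tests/ops/test_dropout_layer_norm.py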