gaoqiong / flash-attention · Commits

Commit d2f4324f
[LayerNorm] Make sure memory addresses are aligned to 16 bytes
Authored Jul 04, 2023 by Tri Dao
Parent: 3a9bfd07

Showing 2 changed files with 43 additions and 35 deletions:
    flash_attn/ops/layer_norm.py            +39  -31
    tests/ops/test_dropout_layer_norm.py     +4   -4

flash_attn/ops/layer_norm.py  (view file @ d2f4324f)
@@ -7,9 +7,17 @@ from torch.nn import init
 import dropout_layer_norm
 
 
+def maybe_align(x, alignment_in_bytes=16):
+    """Assume that x already has last dim divisible by alignment_in_bytes
+    """
+    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
+    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
+    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
+
+
 def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
                                     epsilon, residual_in_fp32=False, is_rms_norm=False):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     """
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))
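Not part of the commit, just a minimal sketch of why the check matters: a tensor carved out of a larger flat buffer can be contiguous yet start at an address that is not a multiple of 16 bytes, in which case maybe_align falls back to clone(). The buffer size and shapes below are illustrative assumptions.

import torch

def maybe_align(x, alignment_in_bytes=16):
    # Same logic as the helper added in this commit.
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()

# Contiguous fp16 view taken from inside a bigger 1-D buffer: its first element sits
# 4 * 2 = 8 bytes past the (aligned) base pointer, so data_ptr() % 16 != 0.
buf = torch.empty(4100, dtype=torch.float16)
x = buf[4:4100].view(8, 512)
print(x.is_contiguous(), x.data_ptr() % 16)   # True, 8 (assuming buf itself is aligned)

y = maybe_align(x)            # falls back to clone(); fresh allocations from PyTorch's
print(y.data_ptr() % 16)      # allocator are normally at least 16-byte aligned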
@@ -26,7 +34,7 @@ def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscal
 
 def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                      dropout_p, has_residual, is_rms_norm=False):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + residual was not returned in the fwd).
     x0 must not be None if we have colscale.

@@ -54,7 +62,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
 def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
                                            out_subset, dropout_p, epsilon, rowscale_const,
                                            out_numrows, residual_in_fp32=False, is_rms_norm=False):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     """
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))

@@ -73,7 +81,7 @@ def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale,
 def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                             x0_subset, out_subset, dropout_p, rowscale_const,
                                             x0_numrows, has_residual, is_rms_norm=False):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + residual was not returned in the fwd).
     x0 must not be None if we have colscale.

@@ -103,7 +111,7 @@ def _dropout_add_layer_norm_parallel_residual_forward(
     x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
     epsilon, residual_in_fp32=False, is_rms_norm=False
 ):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     """
     hidden_size = gamma0.numel()
     x0mat = x0.view((-1, hidden_size))

@@ -122,7 +130,7 @@ def _dropout_add_layer_norm_parallel_residual_backward(
     dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
     dropout_p, has_x1, has_residual, is_rms_norm=False
 ):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + residual was not returned in the fwd).
     """

@@ -143,12 +151,12 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                 residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
-        x0 = x0.contiguous()
-        residual = residual.contiguous() if residual is not None else None
-        gamma = gamma.contiguous()
-        beta = beta.contiguous() if beta is not None else None
-        rowscale = rowscale.contiguous() if rowscale is not None else None
-        colscale = colscale.contiguous() if colscale is not None else None
+        x0 = maybe_align(x0.contiguous(), 16)
+        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
+        gamma = maybe_align(gamma.contiguous(), 16)
+        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
+        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
+        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
             x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
             residual_in_fp32, is_rms_norm

@@ -174,8 +182,8 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dz, *args):
         # assert dz.is_contiguous()
-        dz = dz.contiguous()  # this happens!
-        dx = args[0].contiguous() if ctx.prenorm else None
+        dz = maybe_align(dz.contiguous(), 16)  # this happens!
+        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
         x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
         # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
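Not from the repo, just an illustration of the "# this happens!" comment above: autograd can hand backward() a non-contiguous gradient, for example when a transpose sits downstream of the op, which is why dz is first made contiguous and then passed through maybe_align. A small sketch under that assumption:

import torch

class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, dz):
        # With a transpose downstream, dz usually arrives as a non-contiguous view.
        print("incoming grad contiguous:", dz.is_contiguous())
        return dz.contiguous()   # mirrors the dz.contiguous() step done before maybe_align

x = torch.randn(4, 8, requires_grad=True)
(Identity.apply(x).t() * torch.randn(8, 4)).sum().backward()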
@@ -196,11 +204,11 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
     def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
                 rowscale_const, out_numrows, residual_in_fp32=False,
                 prenorm=False, is_rms_norm=False, return_dmask=False):
-        x0 = x0.contiguous()
-        residual = residual.contiguous() if residual is not None else None
-        gamma = gamma.contiguous()
-        beta = beta.contiguous() if beta is not None else None
-        colscale = colscale.contiguous() if colscale is not None else None
+        x0 = maybe_align(x0.contiguous(), 16)
+        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
+        gamma = maybe_align(gamma.contiguous(), 16)
+        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
+        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
             x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
             rowscale_const, out_numrows, residual_in_fp32, is_rms_norm

@@ -231,8 +239,8 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dz, *args):
         # assert dz.is_contiguous()
-        dz = dz.contiguous()  # this happens!
-        dx = args[0].contiguous() if ctx.prenorm else None
+        dz = maybe_align(dz.contiguous(), 16)  # this happens!
+        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
         x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
         # x0 is None if colscale is None
         dropout_p = ctx.dropout_p

@@ -252,13 +260,13 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
                 residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
-        x0 = x0.contiguous()
-        x1 = x1.contiguous() if x1 is not None else None
-        residual = residual.contiguous() if residual is not None else None
-        gamma0 = gamma0.contiguous()
-        beta0 = beta0.contiguous() if beta0 is not None else None
-        gamma1 = gamma1.contiguous() if gamma1 is not None else None
-        beta1 = beta1.contiguous() if beta1 is not None else None
+        x0 = maybe_align(x0.contiguous(), 16)
+        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
+        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
+        gamma0 = maybe_align(gamma0.contiguous(), 16)
+        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
+        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
+        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
         z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
             x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
             residual_in_fp32, is_rms_norm

@@ -284,9 +292,9 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, dz0, dz1, *args):
-        dz0 = dz0.contiguous()  # this happens!
-        dz1 = dz1.contiguous() if dz1 is not None else None
-        dx = args[0].contiguous() if ctx.prenorm else None
+        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
+        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
+        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
         x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
         dropout_p = ctx.dropout_p
         has_x1 = ctx.has_x1
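All of the autograd wrappers above now funnel their tensors through maybe_align, whose in-code comment is unsure whether clone() always re-aligns storage. A quick empirical probe of that question (a sketch, not part of the commit; it relies on allocator behavior rather than any documented PyTorch guarantee):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
misaligned = 0
for n in range(1, 1001):
    base = torch.empty(n + 8, dtype=torch.float16, device=device)
    view = base[1:n + 1]                  # contiguous, but offset 2 bytes from the base pointer
    if view.clone().data_ptr() % 16 != 0:  # clone() allocates fresh storage
        misaligned += 1
print(f"misaligned clones: {misaligned} / 1000")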
tests/ops/test_dropout_layer_norm.py  (view file @ d2f4324f)

@@ -99,7 +99,7 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
     model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
-                                     model.epsilon, rowscale=rowscale, layerscale=colscale,
+                                     model.eps, rowscale=rowscale, layerscale=colscale,
                                      residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
     assert out.dtype == input_dtype
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')

@@ -251,7 +251,7 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
     model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, residual, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
-                                               model.epsilon, rowscale=rowscale,
+                                               model.eps, rowscale=rowscale,
                                                layerscale=colscale, prenorm=True,
                                                residual_in_fp32=residual_in_fp32,
                                                return_dropout_mask=True)

@@ -412,7 +412,7 @@ def test_dropout_layer_norm_subset_training(
 
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, dmask = dropout_add_layer_norm_subset(
-        x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
+        x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
         x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
         out_numrows=out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
         return_dropout_mask=True)

@@ -532,7 +532,7 @@ def test_dropout_layer_norm_subset_prenorm_training(
 
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, residual, dmask = dropout_add_layer_norm_subset(
-        x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
+        x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
        x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
         out_numrows=out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
         return_dropout_mask=True)
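For context, the test calls touched here follow the public dropout_add_layer_norm API exercised throughout this file. A minimal usage sketch (an assumption-laden example, not from the repo: it assumes a CUDA GPU, a built dropout_layer_norm extension, and illustrative shapes, and mirrors the positional order x0, residual, weight, bias, dropout_p, eps seen above):

import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm

hidden_size = 1024
x0 = torch.randn(8, 512, hidden_size, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x0)
weight = torch.ones(hidden_size, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden_size, device="cuda", dtype=torch.float16)

# Fused dropout + residual add + LayerNorm; the dropout mask is returned for inspection.
out, dmask = dropout_add_layer_norm(
    x0, residual, weight, bias, 0.1, 1e-5,
    residual_in_fp32=False, return_dropout_mask=True,
)
assert out.dtype == x0.dtype
print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")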