gaoqiong / flash-attention · Commits · 393882bc
Commit 393882bc, authored Mar 29, 2023 by Tri Dao
[LayerNorm] Implement LN with parallel residual, support dim 8k
parent 009a3e71
Showing 6 changed files with 499 additions and 67 deletions.
flash_attn/modules/block.py            +30 −15
flash_attn/ops/layer_norm.py           +109 −2
flash_attn/ops/rms_norm.py             +14 −0
tests/models/test_gpt_neox.py          +1 −1
tests/models/test_gptj.py              +1 −1
tests/ops/test_dropout_layer_norm.py   +344 −48
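The commit adds a fused dropout + residual-add + LayerNorm path for parallel-residual blocks (the GPT-J / GPT-NeoX layout, where attention and MLP read the same residual stream and both write back into it). As orientation only, here is a minimal unfused sketch of that block layout in plain PyTorch; the attn/mlp stand-ins are placeholders, not code from this commit:

import torch
import torch.nn as nn

def parallel_residual_block_reference(residual, attn, mlp, norm1, norm2, dropout):
    # Parallel-residual block: both branches read from the same residual stream,
    # and their (dropped-out) outputs are added back into that single stream.
    # This commit fuses the dropout + add + the two norms, not attn/mlp themselves.
    hidden1 = norm1(residual)   # input to the attention branch
    hidden2 = norm2(residual)   # input to the MLP branch (tied to norm1 in some configs)
    return residual + dropout(attn(hidden1)) + dropout(mlp(hidden2))

# Tiny usage example with stand-in sub-modules.
dim = 16
x = torch.randn(2, 4, dim)
out = parallel_residual_block_reference(
    x, attn=nn.Linear(dim, dim), mlp=nn.Linear(dim, dim),
    norm1=nn.LayerNorm(dim), norm2=nn.LayerNorm(dim), dropout=nn.Dropout(0.1),
)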
flash_attn/modules/block.py

@@ -18,6 +18,11 @@ try:
 except ImportError:
     dropout_add_layer_norm = None
 
+try:
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+except ImportError:
+    dropout_add_layer_norm_parallel_residual = None
+
 
 class Block(nn.Module):

@@ -64,7 +69,7 @@ class Block(nn.Module):
         self.norm2 = norm_cls(dim)
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
+            assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
             assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
         # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,

@@ -214,7 +219,6 @@ class ParallelBlock(nn.Module):
         super().__init__()
         self.tied_norm = tied_norm
         self.fused_dropout_add_ln = fused_dropout_add_ln
-        assert not self.fused_dropout_add_ln, 'This is not implemented for ParallelBlock yet'
         self.residual_in_fp32 = residual_in_fp32
         if mixer_cls is None:
             mixer_cls = partial(MHA, num_heads=dim // 64)

@@ -229,7 +233,7 @@ class ParallelBlock(nn.Module):
         self.norm2 = norm_cls(dim)
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
+            assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
             assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
         # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,

@@ -262,19 +266,30 @@ class ParallelBlock(nn.Module):
             hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
             residual.
         """
-        dropped1 = self.dropout1(hidden_states1)
-        # For the very 1st block, we only want 1 dropout, not two different dropouts
-        if hidden_states2 is not None:
-            dropped2 = self.dropout2(hidden_states2)
-            residual = ((residual + dropped1 + dropped2)
-                        if residual is not None else dropped1 + dropped2)
-        else:
-            residual = (residual + dropped1) if residual is not None else dropped1
-        hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
-        hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
-                          if not self.tied_norm else hidden_states1)
-        if self.residual_in_fp32:
-            residual = residual.to(torch.float32)
+        if not self.fused_dropout_add_ln:
+            dropped1 = self.dropout1(hidden_states1)
+            # For the very 1st block, we only want 1 dropout, not two different dropouts
+            if hidden_states2 is not None:
+                dropped2 = self.dropout2(hidden_states2)
+                residual = ((residual + dropped1 + dropped2)
+                            if residual is not None else dropped1 + dropped2)
+            else:
+                residual = (residual + dropped1) if residual is not None else dropped1
+            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+            hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                              if not self.tied_norm else hidden_states1)
+            if self.residual_in_fp32:
+                residual = residual.to(torch.float32)
+        else:
+            weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
+                              if not self.tied_norm else (None, None))
+            hidden_states1, hidden_states2, residual = dropout_add_layer_norm_parallel_residual(
+                hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
+                weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
+                prenorm=True, residual_in_fp32=self.residual_in_fp32
+            )
+            if self.tied_norm:
+                hidden_states2 = hidden_states1
         if mixer_kwargs is None:
             mixer_kwargs = {}
         hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
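For reference, the fused call in the else branch above is intended to compute the same result as the unfused branch. A pure-PyTorch sketch of what dropout_add_layer_norm_parallel_residual(..., prenorm=True) returns, inferred from the unfused code path (an approximation for clarity, not the CUDA kernel):

import torch
import torch.nn.functional as F

def dropout_add_ln_parallel_residual_ref(x0, x1, residual, w0, b0, w1, b1, p, eps):
    # Drop both branch outputs, add them into the residual stream, then apply
    # one LayerNorm per branch. When the norms are tied, w1/b1 are None and the
    # caller reuses the first output (as ParallelBlock does above).
    dropped0 = F.dropout(x0, p)
    dropped1 = F.dropout(x1, p) if x1 is not None else None
    total = dropped0 if dropped1 is None else dropped0 + dropped1
    residual = total if residual is None else residual + total
    shape = (residual.shape[-1],)
    out0 = F.layer_norm(residual, shape, w0, b0, eps)
    out1 = F.layer_norm(residual, shape, w1, b1, eps) if w1 is not None else None
    return out0, out1, residual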
flash_attn/ops/layer_norm.py

@@ -99,6 +99,46 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
     return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
 
 
+def _dropout_add_layer_norm_parallel_residual_forward(
+    x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+    residual_in_fp32=False, is_rms_norm=False
+):
+    """ Assume that arguments are contiguous """
+    hidden_size = gamma0.numel()
+    x0mat = x0.view((-1, hidden_size))
+    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
+    z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
+        x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+        None, residual_in_fp32, is_rms_norm
+    )
+    # dmask0 and dmask1 are None if dropout_p == 0.0
+    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
+    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
+
+
+def _dropout_add_layer_norm_parallel_residual_backward(
+    dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
+    dropout_p, has_x1, has_residual, is_rms_norm=False
+):
+    """ Assume that arguments are contiguous
+    dx == None means that it was a post-norm architecture
+    (x = drop(x0) + residual was not returned in the fwd).
+    """
+    hidden_size = gamma0.numel()
+    xmat = x.view((-1, hidden_size))
+    dz0mat = dz0.view(xmat.shape)
+    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
+    dxmat = dx.view(xmat.shape) if dx is not None else None
+    dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
+        dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
+        dropout_p, has_x1, has_residual, is_rms_norm
+    )
+    # dresidualmat is None if not has_residual
+    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
+
+
 class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,

@@ -115,7 +155,7 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
         )
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None
-        ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
+        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale)
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p
         ctx.has_residual = residual is not None

@@ -168,7 +208,7 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
         # Only need to save x0 if we need to compute gradient wrt colscale
         x0_saved = x0 if colscale is not None else None
         x_shape = (-1, *x0.shape[1:])
-        ctx.save_for_backward(xmat.view(x_shape), x0, dmask, gamma, mu, rsigma, colscale,
+        ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
                               x0_subset, out_subset)
         ctx.prenorm = prenorm
         ctx.dropout_p = dropout_p

@@ -208,6 +248,60 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
                 None, None, None, None, None, None, None, None)
 
 
+class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
+        x0 = x0.contiguous()
+        x1 = x1.contiguous() if x1 is not None else None
+        residual = residual.contiguous() if residual is not None else None
+        gamma0 = gamma0.contiguous()
+        beta0 = beta0.contiguous() if beta0 is not None else None
+        gamma1 = gamma1.contiguous() if gamma1 is not None else None
+        beta1 = beta1.contiguous() if beta1 is not None else None
+        z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
+            x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+            residual_in_fp32, is_rms_norm
+        )
+        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
+        ctx.prenorm = prenorm
+        ctx.dropout_p = dropout_p
+        ctx.has_x1 = x1 is not None
+        ctx.has_residual = residual is not None
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_beta = beta0 is not None
+        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
+        if not return_dmask:
+            return z if not prenorm else (*z, xmat.view(x0.shape))
+        else:
+            dmask0 = (dmask0.view(x0.shape) if dropout_p > 0.
+                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None
+                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            ctx.mark_non_differentiable(dmask0)
+            ctx.mark_non_differentiable(dmask1)
+            return ((*z, dmask0, dmask1) if not prenorm
+                    else (*z, xmat.view(x0.shape), dmask0, dmask1))
+
+    @staticmethod
+    def backward(ctx, dz0, dz1, *args):
+        dz0 = dz0.contiguous()  # this happens!
+        dz1 = dz1.contiguous() if dz1 is not None else None
+        dx = args[0].contiguous() if ctx.prenorm else None
+        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
+        dropout_p = ctx.dropout_p
+        has_x1 = ctx.has_x1
+        has_residual = ctx.has_residual
+        dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = _dropout_add_layer_norm_parallel_residual_backward(
+            dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p,
+            has_x1, has_residual, ctx.is_rms_norm
+        )
+        dx0 = dx0mat.view(x.shape)
+        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
+        return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1,
+                dbeta1 if ctx.has_beta else None, None, None, None, None, None, None)
+
+
 def layer_norm(x, weight, bias, epsilon):
     return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)

@@ -237,6 +331,19 @@ def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon
     )
 
 
+def dropout_add_layer_norm_parallel_residual(
+    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon,
+    prenorm=False, residual_in_fp32=False, return_dropout_mask=False
+):
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
+    """
+    return DropoutAddLayerNormParallelResidualFn.apply(
+        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon,
+        residual_in_fp32, prenorm, False, return_dropout_mask
+    )
+
+
 class DropoutAddLayerNorm(torch.nn.Module):
     def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):
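A minimal usage sketch of the new public function (illustrative shapes and dtypes; assumes the dropout_layer_norm CUDA extension is built and a GPU is available; the 8192 hidden size echoes the "support dim 8k" in the commit title):

import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual

dim = 8192
x0 = torch.randn(4, 512, dim, device='cuda', dtype=torch.float16)  # e.g. attention branch output
x1 = torch.randn_like(x0)                                          # e.g. MLP branch output
w0 = torch.ones(dim, device='cuda', dtype=torch.float16)
b0 = torch.zeros_like(w0)
w1, b1 = torch.ones_like(w0), torch.zeros_like(w0)

# With prenorm=True the updated residual stream is returned as well;
# residual_in_fp32 only matters here because residual is passed as None.
out0, out1, residual = dropout_add_layer_norm_parallel_residual(
    x0, x1, None, w0, b0, w1, b1, 0.1, 1e-5,
    prenorm=True, residual_in_fp32=True,
)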
flash_attn/ops/rms_norm.py

@@ -5,6 +5,7 @@ import torch
 from torch.nn import init
 
 from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
+from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn
 
 
 def rms_norm(x, weight, epsilon):

@@ -37,6 +38,19 @@ def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon,
     )
 
 
+def dropout_add_rms_norm_parallel_residual(
+    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon,
+    prenorm=False, residual_in_fp32=False, return_dropout_mask=False
+):
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
+    """
+    return DropoutAddLayerNormParallelResidualFn.apply(
+        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon,
+        residual_in_fp32, prenorm, True, return_dropout_mask
+    )
+
+
 class DropoutAddRMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):
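The RMSNorm wrapper reuses DropoutAddLayerNormParallelResidualFn and passes True for the is_rms_norm positional argument. For orientation, a reference sketch of the standard RMSNorm statistic that flag selects in place of LayerNorm's mean/variance (the kernel's exact epsilon placement may differ):

import torch

def rms_norm_ref(x, weight, eps=1e-5):
    # No mean subtraction; scale by the root-mean-square over the feature dim.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight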
tests/models/test_gpt_neox.py

@@ -35,7 +35,7 @@ def test_gpt_neox_optimized(model_name):
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
-    config.fused_dropout_add_ln = False  # We don't support parallel block yet
+    config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
tests/models/test_gptj.py

@@ -36,7 +36,7 @@ def test_gptj_optimized(model_name):
     config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
     config.fused_bias_fc = True
     config.fused_mlp = True
-    config.fused_dropout_add_ln = False  # We don't support parallel block yet
+    config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
tests/ops/test_dropout_layer_norm.py

This diff is collapsed (+344 −48).
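The test diff itself is collapsed above. Purely as an illustration of how a fused op like this can be checked (not the repository's actual test), one can compare the fused call with dropout disabled against a plain-PyTorch reference:

import torch
import torch.nn.functional as F
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual

dim = 1024
x0 = torch.randn(2, 128, dim, device='cuda', dtype=torch.float16)
x1 = torch.randn_like(x0)
w0 = torch.ones(dim, device='cuda', dtype=torch.float16)
b0 = torch.zeros_like(w0)
w1, b1 = torch.ones_like(w0), torch.zeros_like(w0)

# dropout_p=0.0 makes the fused path deterministic and directly comparable.
out0, out1, residual = dropout_add_layer_norm_parallel_residual(
    x0, x1, None, w0, b0, w1, b1, 0.0, 1e-5, prenorm=True, residual_in_fp32=True)

ref_res = x0.float() + x1.float()
ref0 = F.layer_norm(ref_res, (dim,), w0.float(), b0.float(), 1e-5)
ref1 = F.layer_norm(ref_res, (dim,), w1.float(), b1.float(), 1e-5)
assert torch.allclose(out0.float(), ref0, rtol=1e-3, atol=1e-2)
assert torch.allclose(out1.float(), ref1, rtol=1e-3, atol=1e-2)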