chenpangpang / transformers
"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "0095e01874473129481bae2d4135c91e3fecdf87"
Commit 3b0a14b7, authored Dec 12, 2018 by Deyu Fu

add fallback path for apex used in modeling.py
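In essence, apex goes from hard requirement to optional accelerator: the import is wrapped in try/except, and a pure-PyTorch BertLayerNorm stands in when apex is missing. A condensed sketch of the pattern, mirroring the diff below (not a verbatim excerpt):

import torch
from torch import nn

try:
    # Prefer NVIDIA apex's fused CUDA layer norm when it is installed.
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
    # Degrade gracefully instead of raising ImportError as the old code did.
    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")

    class BertLayerNorm(nn.Module):
        """Pure-PyTorch fallback with the same name, constructor, and parameters."""
        def __init__(self, hidden_size, eps=1e-12):
            super(BertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            return self.weight * (x - u) / torch.sqrt(s + self.variance_epsilon) + self.bias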
Parent: c8ea2860
Changes: 1 changed file with 23 additions and 28 deletions

pytorch_pretrained_bert/modeling.py  (+23, −28)
@@ -31,10 +31,6 @@ import shutil
 
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.")
 
 from .file_utils import cached_path
@@ -157,22 +153,24 @@ class BertConfig(object):
         """Serializes this instance to a JSON string."""
         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
 
-class BertLayerNorm(nn.Module):
-    def __init__(self, config, variance_epsilon=1e-12):
-        """Construct a layernorm module in the TF style (epsilon inside the square root).
-        """
-        super(BertLayerNorm, self).__init__()
-        self.gamma = nn.Parameter(torch.ones(config.hidden_size))
-        self.beta = nn.Parameter(torch.zeros(config.hidden_size))
-        self.variance_epsilon = variance_epsilon
-
-    def forward(self, x):
-        u = x.mean(-1, keepdim=True)
-        s = (x - u).pow(2).mean(-1, keepdim=True)
-        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
-        return self.gamma * x + self.beta
+try:
+    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
+except ImportError:
+    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
+    class BertLayerNorm(nn.Module):
+        def __init__(self, hidden_size, eps=1e-12):
+            """Construct a layernorm module in the TF style (epsilon inside the square root).
+            """
+            super(BertLayerNorm, self).__init__()
+            self.weight = nn.Parameter(torch.ones(hidden_size))
+            self.bias = nn.Parameter(torch.zeros(hidden_size))
+            self.variance_epsilon = eps
+
+        def forward(self, x):
+            u = x.mean(-1, keepdim=True)
+            s = (x - u).pow(2).mean(-1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+            return self.weight * x + self.bias
 
 class BertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings.
@@ -185,7 +183,7 @@ class BertEmbeddings(nn.Module):
         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
         # any TensorFlow checkpoint file
-        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, input_ids, token_type_ids=None):
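The remaining hunks are the same mechanical rename at each call site: FusedLayerNorm(config.hidden_size, eps=1e-12) becomes BertLayerNorm(...), which resolves to the fused kernel when apex is importable and to the fallback otherwise. One way to see which implementation was picked at import time (hypothetical usage, assuming the package is installed as pytorch_pretrained_bert):

from pytorch_pretrained_bert import modeling

# 'apex.normalization.fused_layer_norm' if apex is installed,
# 'pytorch_pretrained_bert.modeling' if the pure-PyTorch fallback is in use.
print(modeling.BertLayerNorm.__module__)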
@@ -260,7 +258,7 @@ class BertSelfOutput(nn.Module):
     def __init__(self, config):
         super(BertSelfOutput, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, hidden_states, input_tensor):
@@ -299,7 +297,7 @@ class BertOutput(nn.Module):
     def __init__(self, config):
         super(BertOutput, self).__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
-        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, hidden_states, input_tensor):
@@ -361,7 +359,7 @@ class BertPredictionHeadTransform(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.transform_act_fn = ACT2FN[config.hidden_act] \
             if isinstance(config.hidden_act, str) else config.hidden_act
-        self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
 
     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
@@ -443,12 +441,9 @@ class PreTrainedBertModel(nn.Module):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-        elif isinstance(module, FusedLayerNorm):
+        elif isinstance(module, BertLayerNorm):
             module.bias.data.normal_(mean=0.0, std=self.config.initializer_range)
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-        elif isinstance(module, BertLayerNorm):
-            module.beta.data.normal_(mean=0.0, std=self.config.initializer_range)
-            module.gamma.data.normal_(mean=0.0, std=self.config.initializer_range)
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
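With the rename, the fallback exposes weight and bias (the old class used gamma and beta), so weight initialization no longer needs separate branches for the fused and pure-PyTorch classes; a single isinstance(module, BertLayerNorm) check covers both. A hypothetical helper capturing the unified branch:

def init_layer_norm_(module, std):
    # Works for apex's FusedLayerNorm and the fallback alike, since both
    # expose .weight and .bias after this commit.
    module.bias.data.normal_(mean=0.0, std=std)
    module.weight.data.normal_(mean=0.0, std=std)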