chenpangpang / transformers · Commits
"pytorch_transformers/optimization.py" did not exist on "886cb49792f0d39b24d285726e9434897dd9dc6e"
Commit 7ba83730, authored Nov 13, 2018 by lukovnikov

clean up pr

parent fa0c5a2e
Showing 2 changed files with 15 additions and 5 deletions (+15 -5):
convert_tf_checkpoint_to_pytorch.py  +9 -3
modeling.py  +6 -2
convert_tf_checkpoint_to_pytorch.py
@@ -68,11 +68,17 @@ def convert():
         arrays.append(array)
 
     for name, array in zip(names, arrays):
-        name = name[5:]  # skip "bert/"
+        if not name.startswith("bert"):
+            print("Skipping {}".format(name))
+            continue
+        else:
+            name = name.replace("bert/", "")    # skip "bert/"
         print("Loading {}".format(name))
         name = name.split('/')
-        if name[0] in ['redictions', 'eq_relationship']:
-            print("Skipping")
+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
+        # which are not required for using pretrained model
+        if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m":
+            print("Skipping {}".format("/".join(name)))
             continue
         pointer = model
         for m_name in name:
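This hunk tightens the variable-name filtering when mapping a TensorFlow checkpoint onto the PyTorch model: names that do not start with "bert" and the Adam optimizer accumulators (adam_v, adam_m) are now skipped explicitly. A minimal sketch of that selection logic, run on a few hypothetical variable names rather than a real checkpoint:

```python
# Sketch only: the names below are hypothetical stand-ins for what
# tf.train.list_variables() would return for a real BERT checkpoint.
names = [
    "bert/encoder/layer_0/attention/self/query/kernel",
    "bert/encoder/layer_0/attention/self/query/kernel/adam_m",  # optimizer slot
    "cls/predictions/transform/dense/kernel",                   # not under "bert/"
]

for name in names:
    if not name.startswith("bert"):
        print("Skipping {}".format(name))
        continue
    name = name.replace("bert/", "").split('/')
    # adam_v and adam_m are AdamWeightDecayOptimizer accumulators; they are
    # not needed to use the pretrained weights, so they are skipped as well.
    if name[-1] in ("adam_v", "adam_m"):
        print("Skipping {}".format("/".join(name)))
        continue
    print("Loading {}".format("/".join(name)))
```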
modeling.py
@@ -26,6 +26,10 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}
+
+
 def gelu(x):
     """Implementation of the gelu activation function.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
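The hunk only shows the first lines of gelu's docstring; the activation bodies themselves are unchanged and not displayed. For reference, a self-contained sketch of such a registry, assuming the usual erf-based gelu and sigmoid-gated swish (the exact definitions live elsewhere in modeling.py):

```python
import math
import torch

def gelu(x):
    # erf-based gelu; OpenAI GPT's tanh approximation gives slightly different results.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def swish(x):
    # swish / SiLU: the input scaled by its own sigmoid.
    return x * torch.sigmoid(x)

# The mapping added in this commit. Note that "relu" maps to the nn.ReLU class,
# so that entry would need to be instantiated before being applied to a tensor.
ACT2FN = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}

x = torch.randn(2, 3)
print(ACT2FN["gelu"](x).shape)  # torch.Size([2, 3])
```

In the sketch the functions are defined before the dict so the names resolve at import time; in the hunk above, ACT2FN is inserted before def gelu(x), which only works if gelu and swish are already bound at that point in the module.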
@@ -241,8 +245,8 @@ class BERTIntermediate(nn.Module):
     def __init__(self, config):
         super(BERTIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        act2fn = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}
-        self.intermediate_act_fn = act2fn[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
+        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
+            if isinstance(config.hidden_act, str) else config.hidden_act
 
     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
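The net effect of this hunk is that BERTIntermediate now reads its activation from the module-level ACT2FN table, and config.hidden_act may be either a string key or a callable. A small sketch of that selection pattern using a hypothetical config object (Config and the activations below are stand-ins, not the real classes from modeling.py):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for what modeling.py provides.
def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / (2.0 ** 0.5)))

ACT2FN = {"gelu": gelu, "swish": lambda x: x * torch.sigmoid(x)}

class Config:
    hidden_act = "gelu"     # a string key into ACT2FN, or directly a callable
    hidden_size = 8
    intermediate_size = 32

config = Config()

# Same selection pattern as BERTIntermediate.__init__: strings are looked up,
# anything else is assumed to already be a callable activation.
intermediate_act_fn = ACT2FN[config.hidden_act] \
    if isinstance(config.hidden_act, str) else config.hidden_act

dense = nn.Linear(config.hidden_size, config.intermediate_size)
hidden_states = intermediate_act_fn(dense(torch.randn(1, config.hidden_size)))
print(hidden_states.shape)  # torch.Size([1, 32])
```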