Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3b7fb48c
Commit
3b7fb48c
authored
Sep 25, 2019
by
thomwolf
Browse files
fix loading from tf/pt
parent
a049c804
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
4 deletions
+5
-4
examples/run_tf_glue.py
examples/run_tf_glue.py
+2
-1
pytorch_transformers/modeling_tf_utils.py
pytorch_transformers/modeling_tf_utils.py
+2
-2
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+1
-1
No files found.
examples/run_tf_glue.py
View file @
3b7fb48c
...
@@ -27,7 +27,8 @@ tf_model.save_pretrained('./runs/')
...
@@ -27,7 +27,8 @@ tf_model.save_pretrained('./runs/')
pt_model
=
BertForSequenceClassification
.
from_pretrained
(
'./runs/'
)
pt_model
=
BertForSequenceClassification
.
from_pretrained
(
'./runs/'
)
# Quickly inspect a few predictions
# Quickly inspect a few predictions
inputs
=
tokenizer
.
encode_plus
(
"I said the company is doing great"
,
"The company has good results"
,
add_special_tokens
=
True
)
pred
=
pt_model
(
torch
.
tensor
([
tokens
]))
# Divers
# Divers
import
torch
import
torch
...
...
pytorch_transformers/modeling_tf_utils.py
View file @
3b7fb48c
...
@@ -224,8 +224,8 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -224,8 +224,8 @@ class TFPreTrainedModel(tf.keras.Model):
# Load from a PyTorch checkpoint
# Load from a PyTorch checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
else
:
raise
EnvironmentError
(
"Error no file named {} found in directory {}"
.
format
(
raise
EnvironmentError
(
"Error no file named {} found in directory {}
or `from_pt` set to False
"
.
format
(
tuple
(
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
)
,
[
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
]
,
pretrained_model_name_or_path
))
pretrained_model_name_or_path
))
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
archive_file
=
pretrained_model_name_or_path
archive_file
=
pretrained_model_name_or_path
...
...
pytorch_transformers/modeling_utils.py
View file @
3b7fb48c
...
@@ -304,7 +304,7 @@ class PreTrainedModel(nn.Module):
...
@@ -304,7 +304,7 @@ class PreTrainedModel(nn.Module):
# Load from a PyTorch checkpoint
# Load from a PyTorch checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
else
:
raise
EnvironmentError
(
"Error no file named {} found in directory {}"
.
format
(
raise
EnvironmentError
(
"Error no file named {} found in directory {}
or `from_tf` set to False
"
.
format
(
[
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
+
".index"
],
[
WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
,
TF_WEIGHTS_NAME
+
".index"
],
pretrained_model_name_or_path
))
pretrained_model_name_or_path
))
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment