Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c49cd927
Unverified
Commit
c49cd927
authored
Jul 28, 2020
by
Sam Shleifer
Committed by
GitHub
Jul 28, 2020
Browse files
[Fix] position_ids tests again (#6100)
parent
40796c58
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
4 deletions
+5
-4
src/transformers/modeling_bert.py
src/transformers/modeling_bert.py
+1
-2
tests/test_modeling_auto.py
tests/test_modeling_auto.py
+4
-2
No files found.
src/transformers/modeling_bert.py
View file @
c49cd927
...
@@ -568,6 +568,7 @@ class BertPreTrainedModel(PreTrainedModel):
...
@@ -568,6 +568,7 @@ class BertPreTrainedModel(PreTrainedModel):
config_class
=
BertConfig
config_class
=
BertConfig
load_tf_weights
=
load_tf_weights_in_bert
load_tf_weights
=
load_tf_weights_in_bert
base_model_prefix
=
"bert"
base_model_prefix
=
"bert"
authorized_missing_keys
=
[
r
"position_ids"
]
def
_init_weights
(
self
,
module
):
def
_init_weights
(
self
,
module
):
""" Initialize the weights """
""" Initialize the weights """
...
@@ -699,8 +700,6 @@ class BertModel(BertPreTrainedModel):
...
@@ -699,8 +700,6 @@ class BertModel(BertPreTrainedModel):
"""
"""
authorized_missing_keys
=
[
r
"position_ids"
]
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
config
=
config
self
.
config
=
config
...
...
tests/test_modeling_auto.py
View file @
c49cd927
...
@@ -88,9 +88,11 @@ class AutoModelTest(unittest.TestCase):
...
@@ -88,9 +88,11 @@ class AutoModelTest(unittest.TestCase):
model
,
loading_info
=
AutoModelForPreTraining
.
from_pretrained
(
model_name
,
output_loading_info
=
True
)
model
,
loading_info
=
AutoModelForPreTraining
.
from_pretrained
(
model_name
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
model
,
BertForPreTraining
)
self
.
assertIsInstance
(
model
,
BertForPreTraining
)
# Only one value should not be initialized and in the missing keys.
missing_keys
=
loading_info
.
pop
(
"missing_keys"
)
self
.
assertListEqual
([
"cls.predictions.decoder.bias"
],
missing_keys
)
for
key
,
value
in
loading_info
.
items
():
for
key
,
value
in
loading_info
.
items
():
# Only one value should not be initialized and in the missing keys.
self
.
assertEqual
(
len
(
value
),
0
)
self
.
assertEqual
(
len
(
value
),
1
if
key
==
"missing_keys"
else
0
)
@
slow
@
slow
def
test_lmhead_model_from_pretrained
(
self
):
def
test_lmhead_model_from_pretrained
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment