Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
27d55125
Unverified
Commit
27d55125
authored
May 01, 2020
by
Julien Chaumond
Committed by
GitHub
May 01, 2020
Browse files
Configs: saner num_labels in configs. (#3967)
parent
e80be7f1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
10 deletions
+9
-10
src/transformers/configuration_utils.py
src/transformers/configuration_utils.py
+9
-10
No files found.
src/transformers/configuration_utils.py
View file @
27d55125
...
...
@@ -86,11 +86,13 @@ class PretrainedConfig(object):
# Fine-tuning task arguments
self
.
architectures
=
kwargs
.
pop
(
"architectures"
,
None
)
self
.
finetuning_task
=
kwargs
.
pop
(
"finetuning_task"
,
None
)
self
.
num_labels
=
kwargs
.
pop
(
"num_labels"
,
2
)
self
.
id2label
=
kwargs
.
pop
(
"id2label"
,
{
i
:
f
"LABEL_
{
i
}
"
for
i
in
range
(
self
.
num_labels
)})
self
.
id2label
=
kwargs
.
pop
(
"id2label"
,
None
)
self
.
label2id
=
kwargs
.
pop
(
"label2id"
,
None
)
if
self
.
id2label
is
not
None
:
self
.
id2label
=
dict
((
int
(
key
),
value
)
for
key
,
value
in
self
.
id2label
.
items
())
self
.
label2id
=
kwargs
.
pop
(
"label2id"
,
dict
(
zip
(
self
.
id2label
.
values
(),
self
.
id2label
.
keys
())))
self
.
label2id
=
dict
((
key
,
int
(
value
))
for
key
,
value
in
self
.
label2id
.
items
())
# Keys are always strings in JSON so convert ids to int here.
else
:
self
.
num_labels
=
kwargs
.
pop
(
"num_labels"
,
2
)
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self
.
prefix
=
kwargs
.
pop
(
"prefix"
,
None
)
...
...
@@ -115,15 +117,12 @@ class PretrainedConfig(object):
@
property
def
num_labels
(
self
):
return
self
.
_num_
label
s
return
len
(
self
.
id2
label
)
@
num_labels
.
setter
def
num_labels
(
self
,
num_labels
):
self
.
_num_labels
=
num_labels
self
.
id2label
=
{
i
:
"LABEL_{}"
.
format
(
i
)
for
i
in
range
(
self
.
num_labels
)}
self
.
id2label
=
dict
((
int
(
key
),
value
)
for
key
,
value
in
self
.
id2label
.
items
())
self
.
id2label
=
{
i
:
"LABEL_{}"
.
format
(
i
)
for
i
in
range
(
num_labels
)}
self
.
label2id
=
dict
(
zip
(
self
.
id2label
.
values
(),
self
.
id2label
.
keys
()))
self
.
label2id
=
dict
((
key
,
int
(
value
))
for
key
,
value
in
self
.
label2id
.
items
())
def
save_pretrained
(
self
,
save_directory
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment