chenpangpang / transformers · Commits

Unverified commit ffa17fe3, authored Mar 25, 2020 by Patrick von Platen, committed by GitHub Mar 25, 2020.
Parent: 83272a38

Extend config with task specific configs. (#3433)

* add new default configs
* change prefix default to None
Changes: 3 changed files, with 27 additions and 13 deletions (+27 / -13)

- src/transformers/configuration_utils.py (+10 / -3)
- src/transformers/modeling_tf_utils.py (+8 / -5)
- src/transformers/modeling_utils.py (+9 / -5)
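The headline change is a new `task_specific_params` attribute on `PretrainedConfig`, alongside tokenizer-style defaults (`prefix`, `bos_token_id`, `pad_token_id`, `eos_token_id`, `decoder_start_token_id`). A minimal sketch of the shape such an entry can take; the task names and values below are illustrative, not taken from this commit:

```python
# Illustrative shape of the new config attribute (values are made up):
# each key names a task, each value holds generation overrides for it.
task_specific_params = {
    "summarization": {"max_length": 142, "min_length": 56, "num_beams": 4},
    "translation_en_to_de": {"prefix": "translate English to German: "},
}
```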
src/transformers/configuration_utils.py (view file @ ffa17fe3)
```diff
@@ -78,9 +78,6 @@ class PretrainedConfig(object):
         self.top_k = kwargs.pop("top_k", 50)
         self.top_p = kwargs.pop("top_p", 1.0)
         self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
-        self.bos_token_id = kwargs.pop("bos_token_id", None)
-        self.pad_token_id = kwargs.pop("pad_token_id", None)
-        self.eos_token_id = kwargs.pop("eos_token_id", None)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
@@ -94,6 +91,16 @@ class PretrainedConfig(object):
         self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
         self.label2id = dict((key, int(value)) for key, value in self.label2id.items())

+        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
+        self.prefix = kwargs.pop("prefix", None)
+        self.bos_token_id = kwargs.pop("bos_token_id", None)
+        self.pad_token_id = kwargs.pop("pad_token_id", None)
+        self.eos_token_id = kwargs.pop("eos_token_id", None)
+        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
+
+        # task specific arguments
+        self.task_specific_params = kwargs.pop("task_specific_params", None)
+
         # Additional attributes without default values
         for key, value in kwargs.items():
             try:
```
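The `kwargs.pop(key, default)` pattern above means any key present in a serialized config overrides the hard-coded default, and unknown keys fall through to the "additional attributes" loop. A minimal self-contained sketch of that pattern; the `config_dict` values are illustrative, not from the commit:

```python
# Sketch of the kwargs.pop pattern used in PretrainedConfig.__init__:
# keys found in the loaded config override the defaults (e.g. prefix=None).
config_dict = {
    "prefix": "summarize: ",  # illustrative value
    "task_specific_params": {
        "summarization": {"max_length": 142, "min_length": 56, "num_beams": 4}
    },
}

kwargs = dict(config_dict)
prefix = kwargs.pop("prefix", None)                        # new default: None
task_specific_params = kwargs.pop("task_specific_params", None)

print(prefix)                                              # "summarize: "
print(task_specific_params["summarization"]["num_beams"])  # 4
print(kwargs)                                              # {} -> nothing left over
```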
src/transformers/modeling_tf_utils.py (view file @ ffa17fe3)
```diff
@@ -610,7 +610,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
         )
-        decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
+        decoder_start_token_id = (
+            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
+        )

         if input_ids is not None:
             batch_size = shape_list(input_ids)[0]  # overriden by the input batch_size
@@ -635,9 +637,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         assert (eos_token_id is None) or (
             isinstance(eos_token_id, int) and (eos_token_id >= 0)
         ), "`eos_token_id` should be a positive integer."
-        assert (
-            decoder_start_token_id is not None or self.config.is_encoder_decoder is False
-        ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
         assert length_penalty > 0, "`length_penalty` should be strictely positive."
         assert (
             isinstance(num_return_sequences, int) and num_return_sequences > 0
@@ -708,8 +707,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

         if self.config.is_encoder_decoder:
-            assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
+            if decoder_start_token_id is None:
+                decoder_start_token_id = bos_token_id
+
+            assert (
+                decoder_start_token_id is not None
+            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
             assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
             assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
```
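The net effect of these two hunks is a three-level fallback for the decoder start token: the explicit `generate()` argument wins, then `config.decoder_start_token_id`, then `bos_token_id` as a last resort for encoder-decoder models. A minimal sketch of that precedence; the function name and its signature are hypothetical, introduced only to isolate the logic:

```python
# Sketch of the fallback order implemented above (names are illustrative):
# explicit argument -> config.decoder_start_token_id -> bos_token_id.
def resolve_decoder_start(arg_value, config_value, bos_token_id, is_encoder_decoder):
    token_id = arg_value if arg_value is not None else config_value
    if is_encoder_decoder:
        if token_id is None:  # last resort: fall back to BOS
            token_id = bos_token_id
        assert token_id is not None, (
            "decoder_start_token_id or bos_token_id has to be defined "
            "for encoder-decoder generation"
        )
    return token_id

# e.g. neither the argument nor the config defines a start token, so BOS (0) is used:
assert resolve_decoder_start(None, None, 0, True) == 0
```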
src/transformers/modeling_utils.py (view file @ ffa17fe3)
```diff
@@ -809,7 +809,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
         )
-        decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
+        decoder_start_token_id = (
+            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
+        )

         if input_ids is not None:
             batch_size = input_ids.shape[0]  # overriden by the input batch_size
@@ -831,9 +833,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         assert pad_token_id is None or (
             isinstance(pad_token_id, int) and (pad_token_id >= 0)
         ), "`pad_token_id` should be a positive integer."
-        assert (
-            decoder_start_token_id is not None or self.config.is_encoder_decoder is False
-        ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
         assert (eos_token_id is None) or (
             isinstance(eos_token_id, int) and (eos_token_id >= 0)
         ), "`eos_token_id` should be a positive integer."
@@ -912,7 +911,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

         if self.config.is_encoder_decoder:
-            assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
+            if decoder_start_token_id is None:
+                decoder_start_token_id = bos_token_id
+
+            assert (
+                decoder_start_token_id is not None
+            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
             assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
             assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
```
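Taken together, the changes let a checkpoint ship per-task generation defaults in its config. A hedged usage sketch follows; the checkpoint name, the `"summarization"` key, and the idea of promoting task params onto the config are assumptions for illustration, not part of this commit:

```python
# Hypothetical usage: apply a task-specific parameter set from the config
# before generating. Checkpoint name and parameter keys are illustrative.
from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer

model_name = "bart-large-cnn"  # assumption: a seq2seq checkpoint carrying task_specific_params
config = AutoConfig.from_pretrained(model_name)

task_params = (config.task_specific_params or {}).get("summarization", {})
for key, value in task_params.items():
    setattr(config, key, value)  # promote task defaults to top-level config attributes

model = AutoModelWithLMHead.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# config.prefix (new in this commit, default None) is prepended to the input text
inputs = tokenizer.encode((config.prefix or "") + "Some long article ...", return_tensors="pt")
summary_ids = model.generate(inputs)
```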