Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5aa361f3
Unverified
Commit
5aa361f3
authored
Nov 25, 2020
by
Daniel Khashabi
Committed by
GitHub
Nov 26, 2020
Browse files
finetune.py: specifying generation min_length (#8478)
parent
30e7f7e5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
0 deletions
+6
-0
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+6
-0
No files found.
examples/seq2seq/finetune.py
View file @
5aa361f3
...
...
@@ -113,6 +113,10 @@ class SummarizationModule(BaseTransformer):
self
.
eval_max_length
=
self
.
hparams
.
eval_max_gen_length
else
:
self
.
eval_max_length
=
self
.
model
.
config
.
max_length
if
self
.
hparams
.
eval_min_gen_length
is
not
None
:
self
.
eval_min_length
=
self
.
hparams
.
eval_min_gen_length
else
:
self
.
eval_min_length
=
self
.
model
.
config
.
min_length
self
.
val_metric
=
self
.
default_val_metric
if
self
.
hparams
.
val_metric
is
None
else
self
.
hparams
.
val_metric
def
save_readable_batch
(
self
,
batch
:
Dict
[
str
,
torch
.
Tensor
])
->
Dict
[
str
,
List
[
str
]]:
...
...
@@ -219,6 +223,7 @@ class SummarizationModule(BaseTransformer):
decoder_start_token_id
=
self
.
decoder_start_token_id
,
num_beams
=
self
.
eval_beams
,
max_length
=
self
.
eval_max_length
,
min_length
=
self
.
eval_min_length
,
)
gen_time
=
(
time
.
time
()
-
t0
)
/
batch
[
"input_ids"
].
shape
[
0
]
preds
:
List
[
str
]
=
self
.
ids_to_clean_text
(
generated_ids
)
...
...
@@ -346,6 +351,7 @@ class SummarizationModule(BaseTransformer):
"--val_metric"
,
type
=
str
,
default
=
None
,
required
=
False
,
choices
=
[
"bleu"
,
"rouge2"
,
"loss"
,
None
]
)
parser
.
add_argument
(
"--eval_max_gen_length"
,
type
=
int
,
default
=
None
,
help
=
"never generate more than n tokens"
)
parser
.
add_argument
(
"--eval_min_gen_length"
,
type
=
int
,
default
=
None
,
help
=
"never generate shorter than n tokens"
)
parser
.
add_argument
(
"--save_top_k"
,
type
=
int
,
default
=
1
,
required
=
False
,
help
=
"How many checkpoints to save"
)
parser
.
add_argument
(
"--early_stopping_patience"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment