Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e9a2f772
Unverified
Commit
e9a2f772
authored
Sep 10, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 10, 2020
Browse files
[s2s] --eval_max_generate_length (#7018)
parent
df4594a9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
6 deletions
+15
-6
examples/seq2seq/distil_marian_enro_teacher.sh
examples/seq2seq/distil_marian_enro_teacher.sh
+1
-1
examples/seq2seq/distil_marian_no_teacher.sh
examples/seq2seq/distil_marian_no_teacher.sh
+1
-1
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+12
-4
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+1
-0
No files found.
examples/seq2seq/distil_marian_enro_teacher.sh
View file @
e9a2f772
...
@@ -16,5 +16,5 @@ python distillation.py \
...
@@ -16,5 +16,5 @@ python distillation.py \
--train_batch_size
=
$BS
--eval_batch_size
=
$BS
\
--train_batch_size
=
$BS
--eval_batch_size
=
$BS
\
--tokenizer_name
Helsinki-NLP/opus-mt-en-ro
\
--tokenizer_name
Helsinki-NLP/opus-mt-en-ro
\
--warmup_steps
500
--logger_name
wandb
\
--warmup_steps
500
--logger_name
wandb
\
--fp16_opt_level
O1
--task
translation
--normalize_hidden
\
--fp16_opt_level
O1
--task
translation
--normalize_hidden
--num_sanity_val_steps
=
0
\
"
$@
"
"
$@
"
examples/seq2seq/distil_marian_no_teacher.sh
View file @
e9a2f772
...
@@ -13,5 +13,5 @@ python distillation.py \
...
@@ -13,5 +13,5 @@ python distillation.py \
--train_batch_size
=
$BS
--eval_batch_size
=
$BS
\
--train_batch_size
=
$BS
--eval_batch_size
=
$BS
\
--tokenizer_name
$m
--model_name_or_path
$m
\
--tokenizer_name
$m
--model_name_or_path
$m
\
--warmup_steps
500
--sortish_sampler
--logger_name
wandb
\
--warmup_steps
500
--sortish_sampler
--logger_name
wandb
\
--gpus
1
--fp16_opt_level
=
O1
--task
translation
\
--gpus
1
--fp16_opt_level
=
O1
--task
translation
--num_sanity_val_steps
=
0
\
"
$@
"
"
$@
"
examples/seq2seq/finetune.py
View file @
e9a2f772
...
@@ -11,7 +11,6 @@ from typing import Dict, List, Tuple
...
@@ -11,7 +11,6 @@ from typing import Dict, List, Tuple
import
numpy
as
np
import
numpy
as
np
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
import
torch
import
torch
from
packaging
import
version
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
lightning_base
import
BaseTransformer
,
add_generic_args
,
generic_train
from
lightning_base
import
BaseTransformer
,
add_generic_args
,
generic_train
...
@@ -94,6 +93,9 @@ class SummarizationModule(BaseTransformer):
...
@@ -94,6 +93,9 @@ class SummarizationModule(BaseTransformer):
"val"
:
self
.
hparams
.
val_max_target_length
,
"val"
:
self
.
hparams
.
val_max_target_length
,
"test"
:
self
.
hparams
.
test_max_target_length
,
"test"
:
self
.
hparams
.
test_max_target_length
,
}
}
if
self
.
hparams
.
sortish_sampler
and
self
.
hparams
.
gpus
>
1
:
self
.
hparams
.
sortish_sampler
=
False
warnings
.
warn
(
"ignoring sortish_sampler as it is unsupported on multiple GPUs"
)
assert
self
.
target_lens
[
"train"
]
<=
self
.
target_lens
[
"val"
],
f
"target_lens:
{
self
.
target_lens
}
"
assert
self
.
target_lens
[
"train"
]
<=
self
.
target_lens
[
"val"
],
f
"target_lens:
{
self
.
target_lens
}
"
assert
self
.
target_lens
[
"train"
]
<=
self
.
target_lens
[
"test"
],
f
"target_lens:
{
self
.
target_lens
}
"
assert
self
.
target_lens
[
"train"
]
<=
self
.
target_lens
[
"test"
],
f
"target_lens:
{
self
.
target_lens
}
"
...
@@ -114,6 +116,10 @@ class SummarizationModule(BaseTransformer):
...
@@ -114,6 +116,10 @@ class SummarizationModule(BaseTransformer):
)
)
self
.
eval_beams
=
self
.
model
.
config
.
num_beams
if
self
.
hparams
.
eval_beams
is
None
else
self
.
hparams
.
eval_beams
self
.
eval_beams
=
self
.
model
.
config
.
num_beams
if
self
.
hparams
.
eval_beams
is
None
else
self
.
hparams
.
eval_beams
assert
self
.
eval_beams
>=
1
,
f
"got self.eval_beams=
{
self
.
eval_beams
}
. Need an integer > 1"
assert
self
.
eval_beams
>=
1
,
f
"got self.eval_beams=
{
self
.
eval_beams
}
. Need an integer > 1"
if
self
.
hparams
.
eval_max_gen_length
is
not
None
:
self
.
eval_max_length
=
self
.
hparams
.
eval_max_gen_length
else
:
self
.
eval_max_length
=
self
.
model
.
config
.
max_length
self
.
val_metric
=
self
.
default_val_metric
if
self
.
hparams
.
val_metric
is
None
else
self
.
hparams
.
val_metric
self
.
val_metric
=
self
.
default_val_metric
if
self
.
hparams
.
val_metric
is
None
else
self
.
hparams
.
val_metric
def
freeze_embeds
(
self
):
def
freeze_embeds
(
self
):
...
@@ -209,12 +215,15 @@ class SummarizationModule(BaseTransformer):
...
@@ -209,12 +215,15 @@ class SummarizationModule(BaseTransformer):
def
_generative_step
(
self
,
batch
:
dict
)
->
dict
:
def
_generative_step
(
self
,
batch
:
dict
)
->
dict
:
t0
=
time
.
time
()
t0
=
time
.
time
()
# parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
generated_ids
=
self
.
model
.
generate
(
generated_ids
=
self
.
model
.
generate
(
batch
[
"input_ids"
],
batch
[
"input_ids"
],
attention_mask
=
batch
[
"attention_mask"
],
attention_mask
=
batch
[
"attention_mask"
],
use_cache
=
True
,
use_cache
=
True
,
decoder_start_token_id
=
self
.
decoder_start_token_id
,
decoder_start_token_id
=
self
.
decoder_start_token_id
,
num_beams
=
self
.
eval_beams
,
num_beams
=
self
.
eval_beams
,
max_length
=
self
.
eval_max_length
,
)
)
gen_time
=
(
time
.
time
()
-
t0
)
/
batch
[
"input_ids"
].
shape
[
0
]
gen_time
=
(
time
.
time
()
-
t0
)
/
batch
[
"input_ids"
].
shape
[
0
]
preds
:
List
[
str
]
=
self
.
ids_to_clean_text
(
generated_ids
)
preds
:
List
[
str
]
=
self
.
ids_to_clean_text
(
generated_ids
)
...
@@ -248,7 +257,7 @@ class SummarizationModule(BaseTransformer):
...
@@ -248,7 +257,7 @@ class SummarizationModule(BaseTransformer):
dataset
=
self
.
get_dataset
(
type_path
)
dataset
=
self
.
get_dataset
(
type_path
)
sampler
=
None
sampler
=
None
if
self
.
hparams
.
sortish_sampler
and
type_path
==
"train"
:
if
self
.
hparams
.
sortish_sampler
and
type_path
==
"train"
:
assert
self
.
hparams
.
gpus
<=
1
#
TODO: assert earlier
assert
self
.
hparams
.
gpus
<=
1
#
this should never break because of the assertion in __init__
sampler
=
dataset
.
make_sortish_sampler
(
batch_size
)
sampler
=
dataset
.
make_sortish_sampler
(
batch_size
)
shuffle
=
False
shuffle
=
False
...
@@ -321,6 +330,7 @@ class SummarizationModule(BaseTransformer):
...
@@ -321,6 +330,7 @@ class SummarizationModule(BaseTransformer):
parser
.
add_argument
(
parser
.
add_argument
(
"--val_metric"
,
type
=
str
,
default
=
None
,
required
=
False
,
choices
=
[
"bleu"
,
"rouge2"
,
"loss"
,
None
]
"--val_metric"
,
type
=
str
,
default
=
None
,
required
=
False
,
choices
=
[
"bleu"
,
"rouge2"
,
"loss"
,
None
]
)
)
parser
.
add_argument
(
"--eval_max_gen_length"
,
type
=
int
,
default
=
None
,
help
=
"never generate more than n tokens"
)
parser
.
add_argument
(
"--save_top_k"
,
type
=
int
,
default
=
1
,
required
=
False
,
help
=
"How many checkpoints to save"
)
parser
.
add_argument
(
"--save_top_k"
,
type
=
int
,
default
=
1
,
required
=
False
,
help
=
"How many checkpoints to save"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--early_stopping_patience"
,
"--early_stopping_patience"
,
...
@@ -356,8 +366,6 @@ def main(args, model=None) -> SummarizationModule:
...
@@ -356,8 +366,6 @@ def main(args, model=None) -> SummarizationModule:
model
:
SummarizationModule
=
SummarizationModule
(
args
)
model
:
SummarizationModule
=
SummarizationModule
(
args
)
else
:
else
:
model
:
SummarizationModule
=
TranslationModule
(
args
)
model
:
SummarizationModule
=
TranslationModule
(
args
)
if
version
.
parse
(
torch
.
__version__
)
==
version
.
parse
(
"1.6"
)
and
args
.
fp16
:
warnings
.
warn
(
"FP16 only seems to work with torch 1.5+apex"
)
dataset
=
Path
(
args
.
data_dir
).
name
dataset
=
Path
(
args
.
data_dir
).
name
if
(
if
(
args
.
logger_name
==
"default"
args
.
logger_name
==
"default"
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
e9a2f772
...
@@ -34,6 +34,7 @@ CHEAP_ARGS = {
...
@@ -34,6 +34,7 @@ CHEAP_ARGS = {
"supervise_forward"
:
True
,
"supervise_forward"
:
True
,
"normalize_hidden"
:
True
,
"normalize_hidden"
:
True
,
"label_smoothing"
:
0.2
,
"label_smoothing"
:
0.2
,
"eval_max_gen_length"
:
None
,
"eval_beams"
:
1
,
"eval_beams"
:
1
,
"val_metric"
:
"loss"
,
"val_metric"
:
"loss"
,
"save_top_k"
:
1
,
"save_top_k"
:
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment