chenpangpang / transformers · Commits · e9a2f772
"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "56ee176c246100cdb1c0ed17338fe89704467e65"
Commit e9a2f772 (unverified)
[s2s] --eval_max_generate_length (#7018)
Authored Sep 10, 2020 by Sam Shleifer; committed by GitHub on Sep 10, 2020
Parent: df4594a9
Showing 4 changed files with 15 additions and 6 deletions (+15 -6):
    examples/seq2seq/distil_marian_enro_teacher.sh   +1  -1
    examples/seq2seq/distil_marian_no_teacher.sh     +1  -1
    examples/seq2seq/finetune.py                     +12 -4
    examples/seq2seq/test_seq2seq_examples.py        +1  -0
examples/seq2seq/distil_marian_enro_teacher.sh

@@ -16,5 +16,5 @@ python distillation.py \
   --train_batch_size=$BS --eval_batch_size=$BS \
   --tokenizer_name Helsinki-NLP/opus-mt-en-ro \
   --warmup_steps 500 --logger_name wandb \
-  --fp16_opt_level O1 --task translation --normalize_hidden \
+  --fp16_opt_level O1 --task translation --normalize_hidden --num_sanity_val_steps=0 \
   "$@"
examples/seq2seq/distil_marian_no_teacher.sh

@@ -13,5 +13,5 @@ python distillation.py \
   --train_batch_size=$BS --eval_batch_size=$BS \
   --tokenizer_name $m --model_name_or_path $m \
   --warmup_steps 500 --sortish_sampler --logger_name wandb \
-  --gpus 1 --fp16_opt_level=O1 --task translation \
+  --gpus 1 --fp16_opt_level=O1 --task translation --num_sanity_val_steps=0 \
   "$@"
examples/seq2seq/finetune.py

@@ -11,7 +11,6 @@ from typing import Dict, List, Tuple
 import numpy as np
 import pytorch_lightning as pl
 import torch
-from packaging import version
 from torch.utils.data import DataLoader

 from lightning_base import BaseTransformer, add_generic_args, generic_train
@@ -94,6 +93,9 @@ class SummarizationModule(BaseTransformer):
             "val": self.hparams.val_max_target_length,
             "test": self.hparams.test_max_target_length,
         }
+        if self.hparams.sortish_sampler and self.hparams.gpus > 1:
+            self.hparams.sortish_sampler = False
+            warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
         assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
         assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
@@ -114,6 +116,10 @@ class SummarizationModule(BaseTransformer):
         )
         self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
         assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
+        if self.hparams.eval_max_gen_length is not None:
+            self.eval_max_length = self.hparams.eval_max_gen_length
+        else:
+            self.eval_max_length = self.model.config.max_length
         self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

     def freeze_embeds(self):
@@ -209,12 +215,15 @@ class SummarizationModule(BaseTransformer):
     def _generative_step(self, batch: dict) -> dict:
         t0 = time.time()
+        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
         generated_ids = self.model.generate(
             batch["input_ids"],
             attention_mask=batch["attention_mask"],
             use_cache=True,
             decoder_start_token_id=self.decoder_start_token_id,
             num_beams=self.eval_beams,
+            max_length=self.eval_max_length,
         )
         gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
         preds: List[str] = self.ids_to_clean_text(generated_ids)
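Taken together, the two hunks above wire the new flag into generation: eval_max_length is either hparams.eval_max_gen_length or the model's config.max_length, and _generative_step passes it to generate() as max_length. A standalone sketch of that flow; the checkpoint name reuses the Marian model from the scripts purely for illustration, and none of this code is part of the commit:

    # Sketch of the eval_max_gen_length -> max_length behaviour.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ro")
    model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ro")

    eval_max_gen_length = 64  # what --eval_max_gen_length would supply; None means "use the model default"
    eval_max_length = eval_max_gen_length if eval_max_gen_length is not None else model.config.max_length

    batch = tok(["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt")
    generated_ids = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        num_beams=1,
        max_length=eval_max_length,  # generation is cut off here even if no EOS token was produced
    )
    print(tok.batch_decode(generated_ids, skip_special_tokens=True))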
@@ -248,7 +257,7 @@ class SummarizationModule(BaseTransformer):
         dataset = self.get_dataset(type_path)
         sampler = None
         if self.hparams.sortish_sampler and type_path == "train":
-            assert self.hparams.gpus <= 1  # TODO: assert earlier
+            assert self.hparams.gpus <= 1  # this should never break because of the assertion in __init__
             sampler = dataset.make_sortish_sampler(batch_size)
             shuffle = False
@@ -321,6 +330,7 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument(
             "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
         )
+        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
         parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
         parser.add_argument(
             "--early_stopping_patience",
@@ -356,8 +366,6 @@ def main(args, model=None) -> SummarizationModule:
             model: SummarizationModule = SummarizationModule(args)
         else:
             model: SummarizationModule = TranslationModule(args)
-    if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
-        warnings.warn("FP16 only seems to work with torch 1.5+apex")
     dataset = Path(args.data_dir).name
     if (
         args.logger_name == "default"
examples/seq2seq/test_seq2seq_examples.py

@@ -34,6 +34,7 @@ CHEAP_ARGS = {
     "supervise_forward": True,
     "normalize_hidden": True,
     "label_smoothing": 0.2,
+    "eval_max_gen_length": None,
     "eval_beams": 1,
     "val_metric": "loss",
     "save_top_k": 1,
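The test defaults keep eval_max_gen_length at None, so the existing tests exercise the fallback to model.config.max_length. A hypothetical illustration (not from the commit) of overriding that default to cap generation length in a fast test:

    # Hypothetical illustration only: override the CHEAP_ARGS default to force a short cap.
    import argparse

    CHEAP_ARGS = {"eval_max_gen_length": None, "eval_beams": 1, "val_metric": "loss", "save_top_k": 1}

    args = argparse.Namespace(**{**CHEAP_ARGS, "eval_max_gen_length": 12})
    # eval_max_gen_length=None -> SummarizationModule falls back to model.config.max_length;
    # eval_max_gen_length=12   -> validation/test generation stops after 12 tokens.
    assert args.eval_max_gen_length == 12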