Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0344428f
Unverified
Commit
0344428f
authored
Aug 25, 2020
by
Sam Shleifer
Committed by
GitHub
Aug 25, 2020
Browse files
[s2s] round bleu, rouge to 4 digits (#6704)
parent
b6512d23
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
12 deletions
+12
-12
examples/seq2seq/distillation.py
examples/seq2seq/distillation.py
+3
-3
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+3
-3
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+3
-3
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+3
-3
No files found.
examples/seq2seq/distillation.py
View file @
0344428f
...
@@ -20,7 +20,7 @@ try:
...
@@ -20,7 +20,7 @@ try:
from
.utils
import
(
from
.utils
import
(
any_requires_grad
,
any_requires_grad
,
assert_all_frozen
,
assert_all_frozen
,
calculate_bleu
_score
,
calculate_bleu
,
freeze_params
,
freeze_params
,
pickle_load
,
pickle_load
,
use_task_specific_params
,
use_task_specific_params
,
...
@@ -32,7 +32,7 @@ except ImportError:
...
@@ -32,7 +32,7 @@ except ImportError:
from
utils
import
(
from
utils
import
(
any_requires_grad
,
any_requires_grad
,
assert_all_frozen
,
assert_all_frozen
,
calculate_bleu
_score
,
calculate_bleu
,
freeze_params
,
freeze_params
,
pickle_load
,
pickle_load
,
use_task_specific_params
,
use_task_specific_params
,
...
@@ -261,7 +261,7 @@ class BartTranslationDistiller(BartSummarizationDistiller):
...
@@ -261,7 +261,7 @@ class BartTranslationDistiller(BartSummarizationDistiller):
self
.
decoder_start_token_id
=
self
.
tokenizer
.
lang_code_to_id
[
hparams
.
tgt_lang
]
self
.
decoder_start_token_id
=
self
.
tokenizer
.
lang_code_to_id
[
hparams
.
tgt_lang
]
def
calc_generative_metrics
(
self
,
preds
,
target
)
->
dict
:
def
calc_generative_metrics
(
self
,
preds
,
target
)
->
dict
:
return
calculate_bleu
_score
(
preds
,
target
)
return
calculate_bleu
(
preds
,
target
)
@
staticmethod
@
staticmethod
def
add_model_specific_args
(
parser
,
root_dir
):
def
add_model_specific_args
(
parser
,
root_dir
):
...
...
examples/seq2seq/finetune.py
View file @
0344428f
...
@@ -23,7 +23,7 @@ try:
...
@@ -23,7 +23,7 @@ try:
Seq2SeqDataset
,
Seq2SeqDataset
,
TranslationDataset
,
TranslationDataset
,
assert_all_frozen
,
assert_all_frozen
,
calculate_bleu
_score
,
calculate_bleu
,
calculate_rouge
,
calculate_rouge
,
flatten_list
,
flatten_list
,
freeze_params
,
freeze_params
,
...
@@ -42,7 +42,7 @@ except ImportError:
...
@@ -42,7 +42,7 @@ except ImportError:
Seq2SeqDataset
,
Seq2SeqDataset
,
TranslationDataset
,
TranslationDataset
,
assert_all_frozen
,
assert_all_frozen
,
calculate_bleu
_score
,
calculate_bleu
,
calculate_rouge
,
calculate_rouge
,
flatten_list
,
flatten_list
,
freeze_params
,
freeze_params
,
...
@@ -325,7 +325,7 @@ class TranslationModule(SummarizationModule):
...
@@ -325,7 +325,7 @@ class TranslationModule(SummarizationModule):
self
.
dataset_kwargs
[
"tgt_lang"
]
=
hparams
.
tgt_lang
self
.
dataset_kwargs
[
"tgt_lang"
]
=
hparams
.
tgt_lang
def
calc_generative_metrics
(
self
,
preds
,
target
)
->
dict
:
def
calc_generative_metrics
(
self
,
preds
,
target
)
->
dict
:
return
calculate_bleu
_score
(
preds
,
target
)
return
calculate_bleu
(
preds
,
target
)
def
main
(
args
,
model
=
None
)
->
SummarizationModule
:
def
main
(
args
,
model
=
None
)
->
SummarizationModule
:
...
...
examples/seq2seq/run_eval.py
View file @
0344428f
...
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
...
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
try
:
try
:
from
.utils
import
calculate_bleu
_score
,
calculate_rouge
,
trim_batch
,
use_task_specific_params
from
.utils
import
calculate_bleu
,
calculate_rouge
,
trim_batch
,
use_task_specific_params
except
ImportError
:
except
ImportError
:
from
utils
import
calculate_bleu
_score
,
calculate_rouge
,
trim_batch
,
use_task_specific_params
from
utils
import
calculate_bleu
,
calculate_rouge
,
trim_batch
,
use_task_specific_params
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
@@ -103,7 +103,7 @@ def run_generate():
...
@@ -103,7 +103,7 @@ def run_generate():
if
args
.
reference_path
is
None
:
if
args
.
reference_path
is
None
:
return
return
# Compute scores
# Compute scores
score_fn
=
calculate_bleu
_score
if
"translation"
in
args
.
task
else
calculate_rouge
score_fn
=
calculate_bleu
if
"translation"
in
args
.
task
else
calculate_rouge
output_lns
=
[
x
.
rstrip
()
for
x
in
open
(
args
.
save_path
).
readlines
()]
output_lns
=
[
x
.
rstrip
()
for
x
in
open
(
args
.
save_path
).
readlines
()]
reference_lns
=
[
x
.
rstrip
()
for
x
in
open
(
args
.
reference_path
).
readlines
()][:
len
(
output_lns
)]
reference_lns
=
[
x
.
rstrip
()
for
x
in
open
(
args
.
reference_path
).
readlines
()][:
len
(
output_lns
)]
scores
:
dict
=
score_fn
(
output_lns
,
reference_lns
)
scores
:
dict
=
score_fn
(
output_lns
,
reference_lns
)
...
...
examples/seq2seq/utils.py
View file @
0344428f
...
@@ -57,9 +57,9 @@ def lmap(f: Callable, x: Iterable) -> List:
...
@@ -57,9 +57,9 @@ def lmap(f: Callable, x: Iterable) -> List:
return
list
(
map
(
f
,
x
))
return
list
(
map
(
f
,
x
))
def
calculate_bleu
_score
(
output_lns
,
refs_lns
,
**
kwargs
)
->
dict
:
def
calculate_bleu
(
output_lns
,
refs_lns
,
**
kwargs
)
->
dict
:
"""Uses sacrebleu's corpus_bleu implementation."""
"""Uses sacrebleu's corpus_bleu implementation."""
return
{
"bleu"
:
corpus_bleu
(
output_lns
,
[
refs_lns
],
**
kwargs
).
score
}
return
{
"bleu"
:
round
(
corpus_bleu
(
output_lns
,
[
refs_lns
],
**
kwargs
).
score
,
4
)
}
def
trim_batch
(
def
trim_batch
(
...
@@ -271,7 +271,7 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer
...
@@ -271,7 +271,7 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer
aggregator
.
add_scores
(
scores
)
aggregator
.
add_scores
(
scores
)
result
=
aggregator
.
aggregate
()
result
=
aggregator
.
aggregate
()
return
{
k
:
v
.
mid
.
fmeasure
*
100
for
k
,
v
in
result
.
items
()}
return
{
k
:
round
(
v
.
mid
.
fmeasure
*
100
,
4
)
for
k
,
v
in
result
.
items
()}
def
freeze_params
(
model
:
nn
.
Module
):
def
freeze_params
(
model
:
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment