Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0344428f
Unverified
Commit
0344428f
authored
Aug 25, 2020
by
Sam Shleifer
Committed by
GitHub
Aug 25, 2020
Browse files
[s2s] round bleu, rouge to 4 digits (#6704)
parent
b6512d23
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
12 deletions
+12
-12
examples/seq2seq/distillation.py
examples/seq2seq/distillation.py
+3
-3
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+3
-3
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+3
-3
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+3
-3
No files found.
examples/seq2seq/distillation.py
View file @
0344428f
...
...
@@ -20,7 +20,7 @@ try:
from
.utils
import
(
any_requires_grad
,
assert_all_frozen
,
calculate_bleu
_score
,
calculate_bleu
,
freeze_params
,
pickle_load
,
use_task_specific_params
,
...
...
@@ -32,7 +32,7 @@ except ImportError:
from
utils
import
(
any_requires_grad
,
assert_all_frozen
,
calculate_bleu
_score
,
calculate_bleu
,
freeze_params
,
pickle_load
,
use_task_specific_params
,
...
...
@@ -261,7 +261,7 @@ class BartTranslationDistiller(BartSummarizationDistiller):
self
.
decoder_start_token_id
=
self
.
tokenizer
.
lang_code_to_id
[
hparams
.
tgt_lang
]
def
calc_generative_metrics
(
self
,
preds
,
target
)
->
dict
:
return
calculate_bleu
_score
(
preds
,
target
)
return
calculate_bleu
(
preds
,
target
)
@
staticmethod
def
add_model_specific_args
(
parser
,
root_dir
):
...
...
examples/seq2seq/finetune.py
View file @
0344428f
...
...
@@ -23,7 +23,7 @@ try:
Seq2SeqDataset
,
TranslationDataset
,
assert_all_frozen
,
calculate_bleu
_score
,
calculate_bleu
,
calculate_rouge
,
flatten_list
,
freeze_params
,
...
...
@@ -42,7 +42,7 @@ except ImportError:
Seq2SeqDataset
,
TranslationDataset
,
assert_all_frozen
,
calculate_bleu
_score
,
calculate_bleu
,
calculate_rouge
,
flatten_list
,
freeze_params
,
...
...
@@ -325,7 +325,7 @@ class TranslationModule(SummarizationModule):
self
.
dataset_kwargs
[
"tgt_lang"
]
=
hparams
.
tgt_lang
def
calc_generative_metrics
(
self
,
preds
,
target
)
->
dict
:
return
calculate_bleu
_score
(
preds
,
target
)
return
calculate_bleu
(
preds
,
target
)
def
main
(
args
,
model
=
None
)
->
SummarizationModule
:
...
...
examples/seq2seq/run_eval.py
View file @
0344428f
...
...
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
try
:
from
.utils
import
calculate_bleu
_score
,
calculate_rouge
,
trim_batch
,
use_task_specific_params
from
.utils
import
calculate_bleu
,
calculate_rouge
,
trim_batch
,
use_task_specific_params
except
ImportError
:
from
utils
import
calculate_bleu
_score
,
calculate_rouge
,
trim_batch
,
use_task_specific_params
from
utils
import
calculate_bleu
,
calculate_rouge
,
trim_batch
,
use_task_specific_params
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
...
@@ -103,7 +103,7 @@ def run_generate():
if
args
.
reference_path
is
None
:
return
# Compute scores
score_fn
=
calculate_bleu
_score
if
"translation"
in
args
.
task
else
calculate_rouge
score_fn
=
calculate_bleu
if
"translation"
in
args
.
task
else
calculate_rouge
output_lns
=
[
x
.
rstrip
()
for
x
in
open
(
args
.
save_path
).
readlines
()]
reference_lns
=
[
x
.
rstrip
()
for
x
in
open
(
args
.
reference_path
).
readlines
()][:
len
(
output_lns
)]
scores
:
dict
=
score_fn
(
output_lns
,
reference_lns
)
...
...
examples/seq2seq/utils.py
View file @
0344428f
...
...
@@ -57,9 +57,9 @@ def lmap(f: Callable, x: Iterable) -> List:
return
list
(
map
(
f
,
x
))
def
calculate_bleu
_score
(
output_lns
,
refs_lns
,
**
kwargs
)
->
dict
:
def
calculate_bleu
(
output_lns
,
refs_lns
,
**
kwargs
)
->
dict
:
"""Uses sacrebleu's corpus_bleu implementation."""
return
{
"bleu"
:
corpus_bleu
(
output_lns
,
[
refs_lns
],
**
kwargs
).
score
}
return
{
"bleu"
:
round
(
corpus_bleu
(
output_lns
,
[
refs_lns
],
**
kwargs
).
score
,
4
)
}
def
trim_batch
(
...
...
@@ -271,7 +271,7 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer
aggregator
.
add_scores
(
scores
)
result
=
aggregator
.
aggregate
()
return
{
k
:
v
.
mid
.
fmeasure
*
100
for
k
,
v
in
result
.
items
()}
return
{
k
:
round
(
v
.
mid
.
fmeasure
*
100
,
4
)
for
k
,
v
in
result
.
items
()}
def
freeze_params
(
model
:
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment