Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
12f14710
Unverified
Commit
12f14710
authored
Jul 17, 2020
by
Patrick von Platen
Committed by
GitHub
Jul 17, 2020
Browse files
[Model card] Bert2Bert
Add Rouge2 results
parent
9d37c56b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
8 deletions
+43
-8
model_cards/patrickvonplaten/bert2bert-cnn_dailymail-fp16/README.md
...s/patrickvonplaten/bert2bert-cnn_dailymail-fp16/README.md
+43
-8
No files found.
model_cards/patrickvonplaten/bert2bert-cnn_dailymail-fp16/README.md
View file @
12f14710
...
@@ -125,12 +125,10 @@ def compute_metrics(pred):
...
@@ -125,12 +125,10 @@ def compute_metrics(pred):
labels_ids
=
pred
.
label_ids
labels_ids
=
pred
.
label_ids
pred_ids
=
pred
.
predictions
pred_ids
=
pred
.
predictions
pred_str
=
tokenizer
.
batch_decode
(
pred_ids
,
clean_special_tokens
=
True
)
# all unnecessary tokens are removed
label_str
=
tokenizer
.
batch_decode
(
labels_ids
,
clean_special_tokens
=
True
)
pred_str
=
tokenizer
.
batch_decode
(
pred_ids
,
skip_special_tokens
=
True
)
label_str
=
tokenizer
.
batch_decode
(
labels_ids
,
skip_special_tokens
=
True
)
pred_str
=
[
pred
.
split
(
"[CLS]"
)[
-
1
].
split
(
"[SEP]"
)[
0
]
for
pred
in
pred_str
]
label_str
=
[
label
.
split
(
"[CLS]"
)[
-
1
].
split
(
"[SEP]"
)[
0
]
for
label
in
label_str
]
rouge_output
=
rouge
.
compute
(
predictions
=
pred_str
,
references
=
label_str
,
rouge_types
=
[
"rouge2"
])[
"rouge2"
].
mid
rouge_output
=
rouge
.
compute
(
predictions
=
pred_str
,
references
=
label_str
,
rouge_types
=
[
"rouge2"
])[
"rouge2"
].
mid
return
{
return
{
...
@@ -189,6 +187,43 @@ trainer = Trainer(
...
@@ -189,6 +187,43 @@ trainer = Trainer(
trainer
.
train
()
trainer
.
train
()
```
```
## Evaluation

The following script evaluates the model on the test set of CNN/Daily Mail.

```python
#!/usr/bin/env python3
import
nlp
from
transformers
import
BertTokenizer
,
EncoderDecoderModel
tokenizer
=
BertTokenizer
.
from_pretrained
(
"patrickvonplaten/bert2bert-cnn_dailymail-fp16"
)
model
=
EncoderDecoderModel
.
from_pretrained
(
"patrickvonplaten/bert2bert-cnn_dailymail-fp16"
)
model
.
to
(
"cuda"
)
test_dataset
=
nlp
.
load_dataset
(
"cnn_dailymail"
,
"3.0.0"
,
split
=
"test"
)
batch_size
=
128
# map data correctly
def
generate_summary
(
batch
):
# Tokenizer will automatically set [BOS] <text> [EOS]
# cut off at BERT max length 512
inputs
=
tokenizer
(
batch
[
"article"
],
padding
=
"max_length"
,
truncation
=
True
,
max_length
=
512
,
return_tensors
=
"pt"
)
input_ids
=
inputs
.
input_ids
.
to
(
"cuda"
)
attention_mask
=
inputs
.
attention_mask
.
to
(
"cuda"
)
outputs
=
model
.
generate
(
input_ids
,
attention_mask
=
attention_mask
)
# all special tokens including will be removed
output_str
=
tokenizer
.
batch_decode
(
outputs
,
skip_special_tokens
=
True
)
batch
[
"pred"
]
=
output_str
return
batch
results
=
test_dataset
.
map
(
generate_summary
,
batched
=
True
,
batch_size
=
batch_size
,
remove_columns
=
[
"article"
])
# load rouge for validation
rouge
=
nlp
.
load_metric
(
"rouge"
)
pred_str
=
results
[
"pred"
]
label_str
=
results
[
"highlights"
]
rouge_output
=
rouge
.
compute
(
predictions
=
pred_str
,
references
=
label_str
,
rouge_types
=
[
"rouge2"
])[
"rouge2"
].
mid
print
(
rouge_output
)
```
The obtained results should be:
| - | Rouge2 - mid - precision | Rouge2 - mid - recall | Rouge2 - mid - fmeasure |
|----------|:-------------:|:------:|:------:|
| **CNN/Daily Mail** | 14.12 | 14.37 | **13.8** |
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment