Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c2ee3840
"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "f6dc2f67783082b433dfa99d4b0a8992ba64be9d"
Commit
c2ee3840
authored
Mar 13, 2020
by
Patrick von Platen
Browse files
update file to new starting token logic
parent
6a82f774
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
2 deletions
+7
-2
examples/summarization/bart/evaluate_cnn.py
examples/summarization/bart/evaluate_cnn.py
+7
-2
No files found.
examples/summarization/bart/evaluate_cnn.py
View file @
c2ee3840
...
@@ -20,6 +20,10 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
...
@@ -20,6 +20,10 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
fout
=
Path
(
out_file
).
open
(
"w"
)
fout
=
Path
(
out_file
).
open
(
"w"
)
model
=
BartForConditionalGeneration
.
from_pretrained
(
"bart-large-cnn"
,
output_past
=
True
,).
to
(
device
)
model
=
BartForConditionalGeneration
.
from_pretrained
(
"bart-large-cnn"
,
output_past
=
True
,).
to
(
device
)
tokenizer
=
BartTokenizer
.
from_pretrained
(
"bart-large"
)
tokenizer
=
BartTokenizer
.
from_pretrained
(
"bart-large"
)
max_length
=
140
min_length
=
55
for
batch
in
tqdm
(
list
(
chunks
(
lns
,
batch_size
))):
for
batch
in
tqdm
(
list
(
chunks
(
lns
,
batch_size
))):
dct
=
tokenizer
.
batch_encode_plus
(
batch
,
max_length
=
1024
,
return_tensors
=
"pt"
,
pad_to_max_length
=
True
)
dct
=
tokenizer
.
batch_encode_plus
(
batch
,
max_length
=
1024
,
return_tensors
=
"pt"
,
pad_to_max_length
=
True
)
summaries
=
model
.
generate
(
summaries
=
model
.
generate
(
...
@@ -27,11 +31,12 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
...
@@ -27,11 +31,12 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
attention_mask
=
dct
[
"attention_mask"
].
to
(
device
),
attention_mask
=
dct
[
"attention_mask"
].
to
(
device
),
num_beams
=
4
,
num_beams
=
4
,
length_penalty
=
2.0
,
length_penalty
=
2.0
,
max_length
=
14
2
,
# +2 from original because we start at step=1 and stop before max_length
max_length
=
max_length
+
2
,
# +2 from original because we start at step=1 and stop before max_length
min_length
=
56
,
# +1 from original because we start at step=1
min_length
=
min_length
+
1
,
# +1 from original because we start at step=1
no_repeat_ngram_size
=
3
,
no_repeat_ngram_size
=
3
,
early_stopping
=
True
,
early_stopping
=
True
,
do_sample
=
False
,
do_sample
=
False
,
decoder_start_token_id
=
model
.
config
.
eos_token_ids
[
0
]
)
)
dec
=
[
tokenizer
.
decode
(
g
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
for
g
in
summaries
]
dec
=
[
tokenizer
.
decode
(
g
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
for
g
in
summaries
]
for
hypothesis
in
dec
:
for
hypothesis
in
dec
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment