Unverified commit d9d0f114, authored Sep 24, 2020 by Sam Shleifer, committed by GitHub on Sep 24, 2020

[s2s] distributed eval allows num_return_sequences > 1 (#7254)

Parent: 0804d077
Showing 5 changed files with 31 additions and 11 deletions (+31, -11):

  examples/seq2seq/README.md                 (+2, -2)
  examples/seq2seq/run_distributed_eval.py   (+21, -2)
  examples/seq2seq/run_eval.py               (+1, -7)
  examples/seq2seq/test_seq2seq_examples.py  (+1, -0)
  examples/seq2seq/utils.py                  (+6, -0)
examples/seq2seq/README.md

@@ -235,7 +235,7 @@ export DATA_DIR=cnn_dm
     --fp16 \
     --bs 32
 ```
-### Multi-GPU Evalulation
+### Multi-GPU Evaluation
 here is a command to run xsum evaluation on 8 GPUS. It is more than linearly faster than run_eval.py in some cases
 because it uses SortishSampler to minimize padding. You can also use it on 1 GPU. `data_dir` must have
 `{type_path}.source` and `{type_path}.target`. Run `./run_distributed_eval.py --help` for all clargs.
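The `data_dir` layout mentioned in the README context above is line-aligned: line i of `{type_path}.source` pairs with line i of `{type_path}.target`. As a quick illustration (not part of the commit; the directory name and contents are invented), a toy `data_dir` for `type_path=val` could be built like this:

```python
# Sketch only: build a toy data_dir in the {type_path}.source / {type_path}.target
# layout that run_eval.py and run_distributed_eval.py expect.
from pathlib import Path

data_dir = Path("toy_cnn_dm")  # hypothetical directory name
data_dir.mkdir(exist_ok=True)

sources = ["First long article to summarize.", "Second long article to summarize."]
targets = ["First summary.", "Second summary."]

# One example per line; source line i pairs with target line i.
(data_dir / "val.source").write_text("\n".join(sources) + "\n")
(data_dir / "val.target").write_text("\n".join(targets) + "\n")
```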
@@ -250,7 +250,7 @@ python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
 Contributions that implement this command for other distributed hardware setups are welcome!
 
-#### run_eval tips and tricks
+#### Single-GPU Eval: Tips and Tricks
 
 When using `run_eval.py`, the following features can be useful:
examples/seq2seq/run_distributed_eval.py

@@ -17,6 +17,7 @@ from utils import (
     Seq2SeqDataset,
     calculate_bleu,
     calculate_rouge,
+    chunks,
     lmap,
     load_json,
     parse_numeric_n_bool_cl_kwargs,
@@ -40,6 +41,7 @@ def eval_data_dir(
     fp16=False,
     task="summarization",
     local_rank=None,
+    num_return_sequences=1,
     src_lang=None,
     tgt_lang=None,
     prefix="",
@@ -56,10 +58,15 @@ def eval_data_dir(
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
     if fp16:
         model = model.half()
+    # determine if we need to increase num_beams
+    use_task_specific_params(model, task)  # update config with task specific params
+    num_beams = generate_kwargs.pop("num_beams", model.config.num_beams)  # AttributeError risk?
+    if num_return_sequences > num_beams:
+        num_beams = num_return_sequences
+
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.
 
-    use_task_specific_params(model, task)  # update config with task specific params
     if max_source_length is None:
         max_source_length = tokenizer.model_max_length
     if prefix is None:
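The added clamp reflects a constraint of beam search in `generate()`: it can only return as many distinct hypotheses per input as it keeps beams, so `num_beams` must be at least `num_return_sequences`. A minimal standalone sketch of the same invariant (illustrative only; the checkpoint name is arbitrary and nothing below is part of the commit):

```python
# Sketch: replicate the num_beams clamp from eval_data_dir above.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "sshleifer/distilbart-xsum-12-3"  # any seq2seq checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

num_return_sequences = 4
# Beam search errors out when num_return_sequences > num_beams, hence the clamp.
num_beams = max(model.config.num_beams, num_return_sequences)

batch = tokenizer(["A long article about distributed evaluation."], return_tensors="pt")
generated = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    num_beams=num_beams,
    num_return_sequences=num_return_sequences,
)
# Output is flat: (batch_size * num_return_sequences, sequence_length).
assert generated.shape[0] == batch["input_ids"].shape[0] * num_return_sequences
```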
@@ -84,10 +91,14 @@ def eval_data_dir(
         summaries = model.generate(
             input_ids=batch["input_ids"].to(model.device),
             attention_mask=batch["attention_mask"].to(model.device),
+            num_return_sequences=num_return_sequences,
+            num_beams=num_beams,
             **generate_kwargs,
         )
         preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         ids = batch["ids"]
+        if num_return_sequences > 1:
+            preds = chunks(preds, num_return_sequences)  # batch size chunks, each of size num_return_seq
         for i, pred in enumerate(preds):
             results.append(dict(pred=pred, id=ids[i].item()))
         save_json(results, save_path)
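Because `generate()` returns its candidates flattened (`batch_size * num_return_sequences` decoded strings in input order), the new `chunks` call regroups them so each entry of `ids` lines up with a list of candidates rather than a single string. A toy illustration of that regrouping (not from the commit):

```python
# Toy illustration: regroup flat generate() output per input example,
# the same shape manipulation the diff above performs.
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]

num_return_sequences = 3
flat_preds = ["a1", "a2", "a3", "b1", "b2", "b3"]  # 2 inputs x 3 candidates each
ids = [17, 42]

results = []
for i, pred in enumerate(chunks(flat_preds, num_return_sequences)):
    results.append(dict(pred=pred, id=ids[i]))

assert results[0] == {"pred": ["a1", "a2", "a3"], "id": 17}
```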
@@ -110,7 +121,6 @@ def run_generate():
     parser.add_argument("--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test")
-    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
     parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
     parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
     parser.add_argument(
@@ -120,6 +130,9 @@ def run_generate():
     parser.add_argument("--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all.")
+    parser.add_argument(
+        "--num_return_sequences", type=int, default=1, required=False, help="How many sequences to return"
+    )
     parser.add_argument(
         "--sync_timeout",
         type=int,
@@ -158,6 +171,7 @@ def run_generate():
         local_rank=args.local_rank,
         n_obs=args.n_obs,
         max_source_length=args.max_source_length,
+        num_return_sequences=args.num_return_sequences,
         prefix=args.prefix,
         src_lang=args.src_lang,
         tgt_lang=args.tgt_lang,
@@ -169,6 +183,11 @@ def run_generate():
     save_dir.mkdir(exist_ok=True)
     partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
     preds = combine_partial_results(partial_results)
+    if args.num_return_sequences > 1:
+        save_path = save_dir.joinpath("pseudolabel_results.json")
+        print(f"Saving aggregated results at {save_path}, intermediate in {json_save_dir}/")
+        save_json(preds, save_path)
+        return
     tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target")
     labels = [x.rstrip() for x in open(tgt_file).readlines()][: len(preds)]
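With `--num_return_sequences > 1` the script now short-circuits before metric computation: the aggregated candidates are written to `pseudolabel_results.json` and no `{type_path}.target` references are needed. A hedged reader sketch (illustrative, not part of the commit; it assumes the file is a JSON list in dataset order, one entry per source line, which is what feeding `combine_partial_results` output into `save_json` suggests):

```python
# Illustrative reader for the aggregated pseudolabel file. Assumed schema:
# a JSON list, one entry per source line; with num_return_sequences > 1
# each entry is that example's list of candidate sequences.
import json

with open("tmp_gen/pseudolabel_results.json") as f:  # hypothetical save_dir
    preds = json.load(f)

first = preds[0]
candidates = first if isinstance(first, list) else [first]
print(f"{len(preds)} examples; example 0 has {len(candidates)} candidate(s)")
```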
examples/seq2seq/run_eval.py

@@ -13,7 +13,7 @@ import torch
 from tqdm import tqdm
 
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
-from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
+from utils import calculate_bleu, calculate_rouge, chunks, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
 
 
 logger = getLogger(__name__)
@@ -22,12 +22,6 @@ logger = getLogger(__name__)
 DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
-def chunks(lst, n):
-    """Yield successive n-sized chunks from lst."""
-    for i in range(0, len(lst), n):
-        yield lst[i : i + n]
-
-
 def generate_summaries_or_translations(
     examples: List[str],
     out_file: str,
examples/seq2seq/test_seq2seq_examples.py

@@ -145,6 +145,7 @@ class TestSummarizationDistiller(unittest.TestCase):
         assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
 
     @require_multigpu
+    @unittest.skip("Broken at the moment")
     def test_multigpu(self):
         updates = dict(
             no_teacher=True,
examples/seq2seq/utils.py

@@ -456,3 +456,9 @@ def write_txt_file(ordered_tgt, path):
     for ln in ordered_tgt:
         f.write(ln + "\n")
     f.flush()
+
+
+def chunks(lst, n):
+    """Yield successive n-sized chunks from lst."""
+    for i in range(0, len(lst), n):
+        yield lst[i : i + n]
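One caveat about the relocated helper: when `len(lst)` is not a multiple of `n`, the final chunk is shorter, so callers must not assume uniform chunk sizes (the eval loop above never hits this, since `generate()` returns exactly `batch_size * num_return_sequences` strings). A quick check (illustrative):

```python
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]

# The trailing chunk is ragged when len(lst) % n != 0.
assert list(chunks([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]
```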