Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0203ad43
Unverified
Commit
0203ad43
authored
Sep 16, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 16, 2020
Browse files
[s2s] distributed eval cleanup (#7186)
parent
3babef81
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
47 additions
and
31 deletions
+47
-31
examples/seq2seq/README.md
examples/seq2seq/README.md
+14
-0
examples/seq2seq/run_distributed_eval.py
examples/seq2seq/run_distributed_eval.py
+18
-19
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+1
-1
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+14
-11
No files found.
examples/seq2seq/README.md
View file @
0203ad43
...
...
@@ -227,6 +227,20 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
--fp16
\
--bs
32
```
### Multi-GPU Evalulation
here is a command to run xsum evaluation on 8 GPUS. It is more than linearly faster than run_eval.py in some cases
because it uses SortishSampler to minimize padding. You can also use it on 1 GPU.
`data_dir`
must have
`{type_path}.source`
and
`{type_path}.target`
. Run
`python run_distributed_eval.py --help`
for all clargs.
```
bash
python
-m
torch.distributed.launch
--nproc_per_node
=
8 run_distributed_eval.py
\
--model_name
sshleifer/distilbart-large-xsum-12-3
\
--save_dir
xsum_generations
\
--data_dir
xsum
\
--fp16
# you can pass generate kwargs like num_beams here, just like run_eval.py
```
Contributions that implement this command for other distributed hardware setups are welcome!
#### run_eval tips and tricks
...
...
examples/seq2seq/run_distributed_eval.py
View file @
0203ad43
...
...
@@ -4,7 +4,7 @@ import time
from
json
import
JSONDecodeError
from
logging
import
getLogger
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
import
torch
from
torch.utils.data
import
DataLoader
...
...
@@ -22,7 +22,7 @@ try:
calculate_rouge
,
lmap
,
load_json
,
parse_numeric_cl_kwargs
,
parse_numeric_
n_bool_
cl_kwargs
,
save_json
,
use_task_specific_params
,
write_txt_file
,
...
...
@@ -34,7 +34,7 @@ except ImportError:
calculate_rouge
,
lmap
,
load_json
,
parse_numeric_cl_kwargs
,
parse_numeric_
n_bool_
cl_kwargs
,
save_json
,
use_task_specific_params
,
write_txt_file
,
...
...
@@ -50,7 +50,6 @@ def eval_data_dir(
type_path
=
"val"
,
n_obs
=
None
,
fp16
=
False
,
num_beams
:
int
=
4
,
task
=
"summarization"
,
local_rank
=
None
,
**
generate_kwargs
,
...
...
@@ -81,23 +80,21 @@ def eval_data_dir(
n_obs
=
n_obs
,
prefix
=
model
.
config
.
prefix
,
)
sampler
=
ds
.
make_sortish_sampler
(
bs
,
distributed
=
True
,
add_extra_examples
=
False
)
# I set shuffle=True for a more accurate progress bar.
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
sampler
=
ds
.
make_sortish_sampler
(
bs
,
distributed
=
True
,
add_extra_examples
=
False
,
shuffle
=
True
)
data_loader
=
DataLoader
(
ds
,
sampler
=
sampler
,
batch_size
=
bs
,
collate_fn
=
ds
.
collate_fn
)
dec_kwargs
=
dict
(
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
# tokenizer.decode
results
=
[]
for
batch
in
tqdm
(
data_loader
):
summaries
=
model
.
generate
(
input_ids
=
batch
[
"input_ids"
].
to
(
model
.
device
),
attention_mask
=
batch
[
"attention_mask"
].
to
(
model
.
device
),
num_beams
=
num_beams
,
**
generate_kwargs
,
)
preds
=
tokenizer
.
batch_decode
(
summaries
,
**
dec_kwargs
)
labels
=
tokenizer
.
batch_decode
(
batch
[
"labels"
],
**
dec_kwargs
)
preds
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
ids
=
batch
[
"ids"
]
for
i
in
range
(
len
(
labels
)):
label
,
pred
=
labels
[
i
],
preds
[
i
]
results
.
append
(
dict
(
pred
=
pred
,
label
=
label
,
id
=
ids
[
i
].
item
()))
for
i
,
pred
in
enumerate
(
preds
):
results
.
append
(
dict
(
pred
=
pred
,
id
=
ids
[
i
].
item
()))
save_json
(
results
,
save_path
)
return
results
,
sampler
.
num_replicas
...
...
@@ -139,8 +136,8 @@ def run_generate():
parser
.
add_argument
(
"--debug"
,
action
=
"store_true"
)
start_time
=
time
.
time
()
args
,
rest
=
parser
.
parse_known_args
()
generate_kwargs
=
parse_numeric_cl_kwargs
(
rest
)
if
generate_kwargs
:
generate_kwargs
=
parse_numeric_
n_bool_
cl_kwargs
(
rest
)
if
generate_kwargs
and
args
.
local_rank
<=
0
:
print
(
f
"parsed the following generate kwargs:
{
generate_kwargs
}
"
)
json_save_dir
=
Path
(
args
.
save_dir
+
"_tmp"
)
Path
(
json_save_dir
).
mkdir
(
exist_ok
=
True
)
# this handles locking.
...
...
@@ -168,7 +165,10 @@ def run_generate():
save_dir
=
Path
(
args
.
save_dir
)
save_dir
.
mkdir
(
exist_ok
=
True
)
partial_results
=
gather_results_from_each_node
(
num_replicas
,
json_save_dir
,
args
.
sync_timeout
)
preds
,
labels
=
combine_partial_results
(
partial_results
)
preds
=
combine_partial_results
(
partial_results
)
tgt_file
=
Path
(
args
.
data_dir
).
joinpath
(
args
.
type_path
+
".target"
)
labels
=
[
x
.
rstrip
()
for
x
in
open
(
tgt_file
).
readlines
()][:
len
(
preds
)]
# Calculate metrics, save metrics, and save _generations.txt
calc_bleu
=
"translation"
in
args
.
task
score_fn
=
calculate_bleu
if
calc_bleu
else
calculate_rouge
...
...
@@ -179,7 +179,7 @@ def run_generate():
metrics
[
"seconds_per_sample"
]
=
round
(
runtime
/
metrics
[
"n_obs"
],
2
)
# TODO(@stas00): add whatever metadata to metrics
metrics_save_path
=
save_dir
.
joinpath
(
f
"
{
args
.
type_path
}
_
{
metric_name
}
.json"
)
save_json
(
metrics
,
metrics_save_path
)
save_json
(
metrics
,
metrics_save_path
,
indent
=
None
)
print
(
metrics
)
write_txt_file
(
preds
,
save_dir
.
joinpath
(
f
"
{
args
.
type_path
}
_generations.txt"
))
if
args
.
debug
:
...
...
@@ -188,15 +188,14 @@ def run_generate():
shutil
.
rmtree
(
json_save_dir
)
def
combine_partial_results
(
partial_results
)
->
Tuple
[
List
,
List
]
:
def
combine_partial_results
(
partial_results
)
->
List
:
"""Concatenate partial results into one file, then sort it by id."""
records
=
[]
for
partial_result
in
partial_results
:
records
.
extend
(
partial_result
)
records
=
list
(
sorted
(
records
,
key
=
lambda
x
:
x
[
"id"
]))
preds
=
[
x
[
"pred"
]
for
x
in
records
]
labels
=
[
x
[
"label"
]
for
x
in
records
]
return
preds
,
labels
return
preds
def
gather_results_from_each_node
(
num_replicas
,
save_dir
,
timeout
)
->
List
[
Dict
[
str
,
List
]]:
...
...
examples/seq2seq/run_eval.py
View file @
0203ad43
...
...
@@ -156,7 +156,7 @@ def run_generate(verbose=True):
scores
[
"info"
]
=
args
.
info
if
verbose
:
print
(
*
scores
)
print
(
scores
)
if
args
.
score_path
is
not
None
:
path
=
args
.
score_path
...
...
examples/seq2seq/utils.py
View file @
0203ad43
...
...
@@ -115,11 +115,11 @@ class AbstractSeq2SeqDataset(Dataset):
def
get_char_lens
(
data_file
):
return
[
len
(
x
)
for
x
in
Path
(
data_file
).
open
().
readlines
()]
def
make_sortish_sampler
(
self
,
batch_size
,
distributed
=
False
,
**
kwargs
):
def
make_sortish_sampler
(
self
,
batch_size
,
distributed
=
False
,
shuffle
=
True
,
**
kwargs
):
if
distributed
:
return
DistributedSortishSampler
(
self
,
batch_size
,
**
kwargs
)
return
DistributedSortishSampler
(
self
,
batch_size
,
shuffle
=
shuffle
,
**
kwargs
)
else
:
return
SortishSampler
(
self
.
src_lens
,
batch_size
)
return
SortishSampler
(
self
.
src_lens
,
batch_size
,
shuffle
=
shuffle
)
def
__getitem__
(
self
,
item
):
raise
NotImplementedError
(
"You must implement this"
)
...
...
@@ -193,18 +193,20 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
class
SortishSampler
(
Sampler
):
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
def
__init__
(
self
,
data
,
batch_size
):
self
.
data
,
self
.
bs
=
data
,
batch_size
def
__init__
(
self
,
data
,
batch_size
,
shuffle
=
True
):
self
.
data
,
self
.
bs
,
self
.
shuffle
=
data
,
batch_size
,
shuffle
def
__len__
(
self
)
->
int
:
return
len
(
self
.
data
)
def
__iter__
(
self
):
return
iter
(
sortish_sampler_indices
(
self
.
data
,
self
.
bs
))
return
iter
(
sortish_sampler_indices
(
self
.
data
,
self
.
bs
,
shuffle
=
self
.
shuffle
))
def
sortish_sampler_indices
(
data
:
List
,
bs
:
int
)
->
np
.
array
:
def
sortish_sampler_indices
(
data
:
List
,
bs
:
int
,
shuffle
=
True
)
->
np
.
array
:
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
if
not
shuffle
:
return
np
.
argsort
(
np
.
array
(
data
)
*
-
1
)
def
key_fn
(
i
):
return
data
[
i
]
...
...
@@ -225,7 +227,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
class
DistributedSortishSampler
(
Sampler
):
"""Copied from torch DistributedSampler"""
def
__init__
(
self
,
dataset
,
batch_size
,
num_replicas
=
None
,
rank
=
None
,
add_extra_examples
=
True
):
def
__init__
(
self
,
dataset
,
batch_size
,
num_replicas
=
None
,
rank
=
None
,
add_extra_examples
=
True
,
shuffle
=
True
):
if
num_replicas
is
None
:
if
not
dist
.
is_available
():
raise
RuntimeError
(
"Requires distributed package to be available"
)
...
...
@@ -246,13 +248,14 @@ class DistributedSortishSampler(Sampler):
self
.
num_samples
=
len
(
self
.
available_indices
)
self
.
batch_size
=
batch_size
self
.
add_extra_examples
=
add_extra_examples
self
.
shuffle
=
shuffle
def
__iter__
(
self
)
->
Iterable
:
g
=
torch
.
Generator
()
g
.
manual_seed
(
self
.
epoch
)
sortish_data
=
[
self
.
dataset
.
src_lens
[
i
]
for
i
in
self
.
available_indices
]
sortish_indices
=
sortish_sampler_indices
(
sortish_data
,
self
.
batch_size
)
sortish_indices
=
sortish_sampler_indices
(
sortish_data
,
self
.
batch_size
,
shuffle
=
self
.
shuffle
)
indices
=
[
self
.
available_indices
[
i
]
for
i
in
sortish_indices
]
assert
len
(
indices
)
==
self
.
num_samples
return
iter
(
indices
)
...
...
@@ -309,9 +312,9 @@ def save_git_info(folder_path: str) -> None:
save_json
(
repo_infos
,
os
.
path
.
join
(
folder_path
,
"git_log.json"
))
def
save_json
(
content
,
path
):
def
save_json
(
content
,
path
,
indent
=
4
,
**
json_dump_kwargs
):
with
open
(
path
,
"w"
)
as
f
:
json
.
dump
(
content
,
f
,
indent
=
4
)
json
.
dump
(
content
,
f
,
indent
=
indent
,
**
json_dump_kwargs
)
def
load_json
(
path
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment