Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0203ad43
Unverified
Commit
0203ad43
authored
Sep 16, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 16, 2020
Browse files
[s2s] distributed eval cleanup (#7186)
parent
3babef81
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
47 additions
and
31 deletions
+47
-31
examples/seq2seq/README.md
examples/seq2seq/README.md
+14
-0
examples/seq2seq/run_distributed_eval.py
examples/seq2seq/run_distributed_eval.py
+18
-19
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+1
-1
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+14
-11
No files found.
examples/seq2seq/README.md
View file @
0203ad43
...
@@ -227,6 +227,20 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
...
@@ -227,6 +227,20 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
--fp16
\
--fp16
\
--bs
32
--bs
32
```
```
### Multi-GPU Evaluation
Here is a command to run XSum evaluation on 8 GPUs. In some cases the speedup over run_eval.py is more than linear in the number of GPUs,
because it uses SortishSampler to minimize padding. You can also use it on 1 GPU.
`data_dir`
must have
`{type_path}.source`
and
`{type_path}.target`
. Run
`python run_distributed_eval.py --help`
for all command-line arguments.
```
bash
python
-m
torch.distributed.launch
--nproc_per_node
=
8 run_distributed_eval.py
\
--model_name
sshleifer/distilbart-large-xsum-12-3
\
--save_dir
xsum_generations
\
--data_dir
xsum
\
--fp16
# you can pass generate kwargs like num_beams here, just like run_eval.py
```
Contributions that implement this command for other distributed hardware setups are welcome!
#### run_eval tips and tricks
#### run_eval tips and tricks
...
...
examples/seq2seq/run_distributed_eval.py
View file @
0203ad43
...
@@ -4,7 +4,7 @@ import time
...
@@ -4,7 +4,7 @@ import time
from
json
import
JSONDecodeError
from
json
import
JSONDecodeError
from
logging
import
getLogger
from
logging
import
getLogger
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
import
torch
import
torch
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
...
@@ -22,7 +22,7 @@ try:
...
@@ -22,7 +22,7 @@ try:
calculate_rouge
,
calculate_rouge
,
lmap
,
lmap
,
load_json
,
load_json
,
parse_numeric_cl_kwargs
,
parse_numeric_
n_bool_
cl_kwargs
,
save_json
,
save_json
,
use_task_specific_params
,
use_task_specific_params
,
write_txt_file
,
write_txt_file
,
...
@@ -34,7 +34,7 @@ except ImportError:
...
@@ -34,7 +34,7 @@ except ImportError:
calculate_rouge
,
calculate_rouge
,
lmap
,
lmap
,
load_json
,
load_json
,
parse_numeric_cl_kwargs
,
parse_numeric_
n_bool_
cl_kwargs
,
save_json
,
save_json
,
use_task_specific_params
,
use_task_specific_params
,
write_txt_file
,
write_txt_file
,
...
@@ -50,7 +50,6 @@ def eval_data_dir(
...
@@ -50,7 +50,6 @@ def eval_data_dir(
type_path
=
"val"
,
type_path
=
"val"
,
n_obs
=
None
,
n_obs
=
None
,
fp16
=
False
,
fp16
=
False
,
num_beams
:
int
=
4
,
task
=
"summarization"
,
task
=
"summarization"
,
local_rank
=
None
,
local_rank
=
None
,
**
generate_kwargs
,
**
generate_kwargs
,
...
@@ -81,23 +80,21 @@ def eval_data_dir(
...
@@ -81,23 +80,21 @@ def eval_data_dir(
n_obs
=
n_obs
,
n_obs
=
n_obs
,
prefix
=
model
.
config
.
prefix
,
prefix
=
model
.
config
.
prefix
,
)
)
sampler
=
ds
.
make_sortish_sampler
(
bs
,
distributed
=
True
,
add_extra_examples
=
False
)
# I set shuffle=True for a more accurate progress bar.
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
sampler
=
ds
.
make_sortish_sampler
(
bs
,
distributed
=
True
,
add_extra_examples
=
False
,
shuffle
=
True
)
data_loader
=
DataLoader
(
ds
,
sampler
=
sampler
,
batch_size
=
bs
,
collate_fn
=
ds
.
collate_fn
)
data_loader
=
DataLoader
(
ds
,
sampler
=
sampler
,
batch_size
=
bs
,
collate_fn
=
ds
.
collate_fn
)
dec_kwargs
=
dict
(
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
# tokenizer.decode
results
=
[]
results
=
[]
for
batch
in
tqdm
(
data_loader
):
for
batch
in
tqdm
(
data_loader
):
summaries
=
model
.
generate
(
summaries
=
model
.
generate
(
input_ids
=
batch
[
"input_ids"
].
to
(
model
.
device
),
input_ids
=
batch
[
"input_ids"
].
to
(
model
.
device
),
attention_mask
=
batch
[
"attention_mask"
].
to
(
model
.
device
),
attention_mask
=
batch
[
"attention_mask"
].
to
(
model
.
device
),
num_beams
=
num_beams
,
**
generate_kwargs
,
**
generate_kwargs
,
)
)
preds
=
tokenizer
.
batch_decode
(
summaries
,
**
dec_kwargs
)
preds
=
tokenizer
.
batch_decode
(
summaries
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
labels
=
tokenizer
.
batch_decode
(
batch
[
"labels"
],
**
dec_kwargs
)
ids
=
batch
[
"ids"
]
ids
=
batch
[
"ids"
]
for
i
in
range
(
len
(
labels
)):
for
i
,
pred
in
enumerate
(
preds
):
label
,
pred
=
labels
[
i
],
preds
[
i
]
results
.
append
(
dict
(
pred
=
pred
,
id
=
ids
[
i
].
item
()))
results
.
append
(
dict
(
pred
=
pred
,
label
=
label
,
id
=
ids
[
i
].
item
()))
save_json
(
results
,
save_path
)
save_json
(
results
,
save_path
)
return
results
,
sampler
.
num_replicas
return
results
,
sampler
.
num_replicas
...
@@ -139,8 +136,8 @@ def run_generate():
...
@@ -139,8 +136,8 @@ def run_generate():
parser
.
add_argument
(
"--debug"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--debug"
,
action
=
"store_true"
)
start_time
=
time
.
time
()
start_time
=
time
.
time
()
args
,
rest
=
parser
.
parse_known_args
()
args
,
rest
=
parser
.
parse_known_args
()
generate_kwargs
=
parse_numeric_cl_kwargs
(
rest
)
generate_kwargs
=
parse_numeric_
n_bool_
cl_kwargs
(
rest
)
if
generate_kwargs
:
if
generate_kwargs
and
args
.
local_rank
<=
0
:
print
(
f
"parsed the following generate kwargs:
{
generate_kwargs
}
"
)
print
(
f
"parsed the following generate kwargs:
{
generate_kwargs
}
"
)
json_save_dir
=
Path
(
args
.
save_dir
+
"_tmp"
)
json_save_dir
=
Path
(
args
.
save_dir
+
"_tmp"
)
Path
(
json_save_dir
).
mkdir
(
exist_ok
=
True
)
# this handles locking.
Path
(
json_save_dir
).
mkdir
(
exist_ok
=
True
)
# this handles locking.
...
@@ -168,7 +165,10 @@ def run_generate():
...
@@ -168,7 +165,10 @@ def run_generate():
save_dir
=
Path
(
args
.
save_dir
)
save_dir
=
Path
(
args
.
save_dir
)
save_dir
.
mkdir
(
exist_ok
=
True
)
save_dir
.
mkdir
(
exist_ok
=
True
)
partial_results
=
gather_results_from_each_node
(
num_replicas
,
json_save_dir
,
args
.
sync_timeout
)
partial_results
=
gather_results_from_each_node
(
num_replicas
,
json_save_dir
,
args
.
sync_timeout
)
preds
,
labels
=
combine_partial_results
(
partial_results
)
preds
=
combine_partial_results
(
partial_results
)
tgt_file
=
Path
(
args
.
data_dir
).
joinpath
(
args
.
type_path
+
".target"
)
labels
=
[
x
.
rstrip
()
for
x
in
open
(
tgt_file
).
readlines
()][:
len
(
preds
)]
# Calculate metrics, save metrics, and save _generations.txt
# Calculate metrics, save metrics, and save _generations.txt
calc_bleu
=
"translation"
in
args
.
task
calc_bleu
=
"translation"
in
args
.
task
score_fn
=
calculate_bleu
if
calc_bleu
else
calculate_rouge
score_fn
=
calculate_bleu
if
calc_bleu
else
calculate_rouge
...
@@ -179,7 +179,7 @@ def run_generate():
...
@@ -179,7 +179,7 @@ def run_generate():
metrics
[
"seconds_per_sample"
]
=
round
(
runtime
/
metrics
[
"n_obs"
],
2
)
metrics
[
"seconds_per_sample"
]
=
round
(
runtime
/
metrics
[
"n_obs"
],
2
)
# TODO(@stas00): add whatever metadata to metrics
# TODO(@stas00): add whatever metadata to metrics
metrics_save_path
=
save_dir
.
joinpath
(
f
"
{
args
.
type_path
}
_
{
metric_name
}
.json"
)
metrics_save_path
=
save_dir
.
joinpath
(
f
"
{
args
.
type_path
}
_
{
metric_name
}
.json"
)
save_json
(
metrics
,
metrics_save_path
)
save_json
(
metrics
,
metrics_save_path
,
indent
=
None
)
print
(
metrics
)
print
(
metrics
)
write_txt_file
(
preds
,
save_dir
.
joinpath
(
f
"
{
args
.
type_path
}
_generations.txt"
))
write_txt_file
(
preds
,
save_dir
.
joinpath
(
f
"
{
args
.
type_path
}
_generations.txt"
))
if
args
.
debug
:
if
args
.
debug
:
...
@@ -188,15 +188,14 @@ def run_generate():
...
@@ -188,15 +188,14 @@ def run_generate():
shutil
.
rmtree
(
json_save_dir
)
shutil
.
rmtree
(
json_save_dir
)
def
combine_partial_results
(
partial_results
)
->
Tuple
[
List
,
List
]
:
def
combine_partial_results
(
partial_results
)
->
List
:
"""Concatenate partial results into one file, then sort it by id."""
"""Concatenate partial results into one file, then sort it by id."""
records
=
[]
records
=
[]
for
partial_result
in
partial_results
:
for
partial_result
in
partial_results
:
records
.
extend
(
partial_result
)
records
.
extend
(
partial_result
)
records
=
list
(
sorted
(
records
,
key
=
lambda
x
:
x
[
"id"
]))
records
=
list
(
sorted
(
records
,
key
=
lambda
x
:
x
[
"id"
]))
preds
=
[
x
[
"pred"
]
for
x
in
records
]
preds
=
[
x
[
"pred"
]
for
x
in
records
]
labels
=
[
x
[
"label"
]
for
x
in
records
]
return
preds
return
preds
,
labels
def
gather_results_from_each_node
(
num_replicas
,
save_dir
,
timeout
)
->
List
[
Dict
[
str
,
List
]]:
def
gather_results_from_each_node
(
num_replicas
,
save_dir
,
timeout
)
->
List
[
Dict
[
str
,
List
]]:
...
...
examples/seq2seq/run_eval.py
View file @
0203ad43
...
@@ -156,7 +156,7 @@ def run_generate(verbose=True):
...
@@ -156,7 +156,7 @@ def run_generate(verbose=True):
scores
[
"info"
]
=
args
.
info
scores
[
"info"
]
=
args
.
info
if
verbose
:
if
verbose
:
print
(
*
scores
)
print
(
scores
)
if
args
.
score_path
is
not
None
:
if
args
.
score_path
is
not
None
:
path
=
args
.
score_path
path
=
args
.
score_path
...
...
examples/seq2seq/utils.py
View file @
0203ad43
...
@@ -115,11 +115,11 @@ class AbstractSeq2SeqDataset(Dataset):
...
@@ -115,11 +115,11 @@ class AbstractSeq2SeqDataset(Dataset):
def
get_char_lens
(
data_file
):
def
get_char_lens
(
data_file
):
return
[
len
(
x
)
for
x
in
Path
(
data_file
).
open
().
readlines
()]
return
[
len
(
x
)
for
x
in
Path
(
data_file
).
open
().
readlines
()]
def
make_sortish_sampler
(
self
,
batch_size
,
distributed
=
False
,
**
kwargs
):
def
make_sortish_sampler
(
self
,
batch_size
,
distributed
=
False
,
shuffle
=
True
,
**
kwargs
):
if
distributed
:
if
distributed
:
return
DistributedSortishSampler
(
self
,
batch_size
,
**
kwargs
)
return
DistributedSortishSampler
(
self
,
batch_size
,
shuffle
=
shuffle
,
**
kwargs
)
else
:
else
:
return
SortishSampler
(
self
.
src_lens
,
batch_size
)
return
SortishSampler
(
self
.
src_lens
,
batch_size
,
shuffle
=
shuffle
)
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
raise
NotImplementedError
(
"You must implement this"
)
raise
NotImplementedError
(
"You must implement this"
)
...
@@ -193,18 +193,20 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
...
@@ -193,18 +193,20 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
class
SortishSampler
(
Sampler
):
class
SortishSampler
(
Sampler
):
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
def
__init__
(
self
,
data
,
batch_size
):
def
__init__
(
self
,
data
,
batch_size
,
shuffle
=
True
):
self
.
data
,
self
.
bs
=
data
,
batch_size
self
.
data
,
self
.
bs
,
self
.
shuffle
=
data
,
batch_size
,
shuffle
def
__len__
(
self
)
->
int
:
def
__len__
(
self
)
->
int
:
return
len
(
self
.
data
)
return
len
(
self
.
data
)
def
__iter__
(
self
):
def
__iter__
(
self
):
return
iter
(
sortish_sampler_indices
(
self
.
data
,
self
.
bs
))
return
iter
(
sortish_sampler_indices
(
self
.
data
,
self
.
bs
,
shuffle
=
self
.
shuffle
))
def
sortish_sampler_indices
(
data
:
List
,
bs
:
int
)
->
np
.
array
:
def
sortish_sampler_indices
(
data
:
List
,
bs
:
int
,
shuffle
=
True
)
->
np
.
array
:
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
if
not
shuffle
:
return
np
.
argsort
(
np
.
array
(
data
)
*
-
1
)
def
key_fn
(
i
):
def
key_fn
(
i
):
return
data
[
i
]
return
data
[
i
]
...
@@ -225,7 +227,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
...
@@ -225,7 +227,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
class
DistributedSortishSampler
(
Sampler
):
class
DistributedSortishSampler
(
Sampler
):
"""Copied from torch DistributedSampler"""
"""Copied from torch DistributedSampler"""
def
__init__
(
self
,
dataset
,
batch_size
,
num_replicas
=
None
,
rank
=
None
,
add_extra_examples
=
True
):
def
__init__
(
self
,
dataset
,
batch_size
,
num_replicas
=
None
,
rank
=
None
,
add_extra_examples
=
True
,
shuffle
=
True
):
if
num_replicas
is
None
:
if
num_replicas
is
None
:
if
not
dist
.
is_available
():
if
not
dist
.
is_available
():
raise
RuntimeError
(
"Requires distributed package to be available"
)
raise
RuntimeError
(
"Requires distributed package to be available"
)
...
@@ -246,13 +248,14 @@ class DistributedSortishSampler(Sampler):
...
@@ -246,13 +248,14 @@ class DistributedSortishSampler(Sampler):
self
.
num_samples
=
len
(
self
.
available_indices
)
self
.
num_samples
=
len
(
self
.
available_indices
)
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
add_extra_examples
=
add_extra_examples
self
.
add_extra_examples
=
add_extra_examples
self
.
shuffle
=
shuffle
def
__iter__
(
self
)
->
Iterable
:
def
__iter__
(
self
)
->
Iterable
:
g
=
torch
.
Generator
()
g
=
torch
.
Generator
()
g
.
manual_seed
(
self
.
epoch
)
g
.
manual_seed
(
self
.
epoch
)
sortish_data
=
[
self
.
dataset
.
src_lens
[
i
]
for
i
in
self
.
available_indices
]
sortish_data
=
[
self
.
dataset
.
src_lens
[
i
]
for
i
in
self
.
available_indices
]
sortish_indices
=
sortish_sampler_indices
(
sortish_data
,
self
.
batch_size
)
sortish_indices
=
sortish_sampler_indices
(
sortish_data
,
self
.
batch_size
,
shuffle
=
self
.
shuffle
)
indices
=
[
self
.
available_indices
[
i
]
for
i
in
sortish_indices
]
indices
=
[
self
.
available_indices
[
i
]
for
i
in
sortish_indices
]
assert
len
(
indices
)
==
self
.
num_samples
assert
len
(
indices
)
==
self
.
num_samples
return
iter
(
indices
)
return
iter
(
indices
)
...
@@ -309,9 +312,9 @@ def save_git_info(folder_path: str) -> None:
...
@@ -309,9 +312,9 @@ def save_git_info(folder_path: str) -> None:
save_json
(
repo_infos
,
os
.
path
.
join
(
folder_path
,
"git_log.json"
))
save_json
(
repo_infos
,
os
.
path
.
join
(
folder_path
,
"git_log.json"
))
def
save_json
(
content
,
path
):
def
save_json
(
content
,
path
,
indent
=
4
,
**
json_dump_kwargs
):
with
open
(
path
,
"w"
)
as
f
:
with
open
(
path
,
"w"
)
as
f
:
json
.
dump
(
content
,
f
,
indent
=
4
)
json
.
dump
(
content
,
f
,
indent
=
indent
,
**
json_dump_kwargs
)
def
load_json
(
path
):
def
load_json
(
path
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment