Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
33d479d2
Unverified
Commit
33d479d2
authored
Sep 14, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 14, 2020
Browse files
[s2s] distributed eval in one command (#7124)
parent
206b78d4
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
125 additions
and
85 deletions
+125
-85
examples/seq2seq/aggregate_distributed_results.py
examples/seq2seq/aggregate_distributed_results.py
+0
-46
examples/seq2seq/romanian_postprocessing.md
examples/seq2seq/romanian_postprocessing.md
+1
-1
examples/seq2seq/run_distributed_eval.py
examples/seq2seq/run_distributed_eval.py
+104
-22
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+20
-16
No files found.
examples/seq2seq/aggregate_distributed_results.py
deleted
100644 → 0
View file @
206b78d4
from
pathlib
import
Path
import
fire
try
:
from
.utils
import
calculate_bleu
,
calculate_rouge
,
load_json
,
save_json
,
write_txt_file
except
ImportError
:
from
utils
import
calculate_bleu
,
calculate_rouge
,
load_json
,
save_json
,
write_txt_file
def combine_partial_results(
    result_dir: str,
    save_dir: str = None,
    save_prefix=None,
    calc_bleu=False,
    just_metrics=False,
):
    """Merge per-rank ``rank*.json`` prediction files into metrics and text files.

    Loads every ``rank*.json`` under ``result_dir``, scores predictions against
    labels (BLEU or ROUGE), saves ``metrics.json`` and, unless ``just_metrics``,
    writes ``{save_prefix}.target`` / ``{save_prefix}.pred_target`` (and
    ``{save_prefix}.source`` when the records carry a "source" field).

    Args:
        result_dir: directory containing the partial ``rank*.json`` files.
        save_dir: output directory; defaults to ``result_dir`` (the original
            default of ``None`` crashed in ``Path(None)``).
        save_prefix: filename prefix for the text outputs; defaults to
            "generated".
        calc_bleu: score with BLEU instead of ROUGE.
        just_metrics: stop after saving/printing metrics.
    """
    src_dir = Path(result_dir)
    # Fix: the old default save_dir=None raised TypeError in Path(None);
    # fall back to writing alongside the inputs.
    save_dir = Path(save_dir) if save_dir is not None else src_dir
    save_dir.mkdir(exist_ok=True)
    paths_to_combine = list(src_dir.glob("rank*.json"))
    records = []
    for partial_result in paths_to_combine:
        records.extend(load_json(partial_result))
    preds = [x["pred"] for x in records]
    labels = [x["label"] for x in records]
    score_fn = calculate_bleu if calc_bleu else calculate_rouge
    metrics = score_fn(preds, labels)
    save_json(metrics, save_dir.joinpath("metrics.json"))  # better would be {prefix}_{rouge|bleu}.json
    print(metrics)
    if just_metrics:
        return
    if save_prefix is None:
        save_prefix = "generated"
        print("using generated as prefix")
    tgt_path = save_dir.joinpath(f"{save_prefix}.target")
    write_txt_file(labels, tgt_path)
    pred_path = save_dir.joinpath(f"{save_prefix}.pred_target")
    write_txt_file(preds, pred_path)
    if "source" in records[0]:
        src_path = save_dir.joinpath(f"{save_prefix}.source")
        write_txt_file([x["source"] for x in records], src_path)
if __name__ == "__main__":
    # CLI entry point: python-fire turns combine_partial_results' signature
    # into command-line flags (e.g. --result_dir ... --calc_bleu).
    fire.Fire(combine_partial_results)
examples/seq2seq/romanian_postprocessing.md
View file @
33d479d2
...
...
@@ -12,7 +12,7 @@ Note: You need to have your test_generations.txt before you start this process.
cd
$HOME
git clone git@github.com:moses-smt/mosesdecoder.git
cd
mosesdecoder
git@github.com:rsennrich/wmt16-scripts.git
git clone
git@github.com:rsennrich/wmt16-scripts.git
```
(2) define a function for post processing.
...
...
examples/seq2seq/run_distributed_eval.py
View file @
33d479d2
import
argparse
import
shutil
import
time
from
json
import
JSONDecodeError
from
logging
import
getLogger
from
pathlib
import
Path
from
typing
import
Dict
from
typing
import
Dict
,
List
,
Tuple
import
torch
from
torch.utils.data
import
DataLoader
...
...
@@ -13,12 +16,29 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
logger
=
getLogger
(
__name__
)
try
:
from
.utils
import
Seq2SeqDataset
,
parse_numeric_cl_kwargs
,
save_json
,
use_task_specific_params
from
.utils
import
(
Seq2SeqDataset
,
calculate_bleu
,
calculate_rouge
,
lmap
,
load_json
,
parse_numeric_cl_kwargs
,
save_json
,
use_task_specific_params
,
write_txt_file
,
)
except
ImportError
:
from
utils
import
Seq2SeqDataset
,
parse_numeric_cl_kwargs
,
save_json
,
use_task_specific_params
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
from
utils
import
(
Seq2SeqDataset
,
calculate_bleu
,
calculate_rouge
,
lmap
,
load_json
,
parse_numeric_cl_kwargs
,
save_json
,
use_task_specific_params
,
write_txt_file
,
)
def
eval_data_dir
(
...
...
@@ -30,7 +50,6 @@ def eval_data_dir(
type_path
=
"val"
,
n_obs
=
None
,
fp16
=
False
,
save_source
=
False
,
num_beams
:
int
=
4
,
task
=
"summarization"
,
local_rank
=
None
,
...
...
@@ -62,7 +81,7 @@ def eval_data_dir(
n_obs
=
n_obs
,
prefix
=
model
.
config
.
prefix
,
)
sampler
=
ds
.
make_sortish_sampler
(
bs
,
distributed
=
True
)
sampler
=
ds
.
make_sortish_sampler
(
bs
,
distributed
=
True
,
add_extra_examples
=
False
)
data_loader
=
DataLoader
(
ds
,
sampler
=
sampler
,
batch_size
=
bs
,
collate_fn
=
ds
.
collate_fn
)
dec_kwargs
=
dict
(
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
)
# tokenizer.decode
results
=
[]
...
...
@@ -75,23 +94,19 @@ def eval_data_dir(
)
preds
=
tokenizer
.
batch_decode
(
summaries
,
**
dec_kwargs
)
labels
=
tokenizer
.
batch_decode
(
batch
[
"labels"
],
**
dec_kwargs
)
if
save_source
:
docs
=
tokenizer
.
batch_decode
(
batch
[
"input_ids"
],
**
dec_kwargs
)
ids
=
batch
[
"ids"
]
for
i
in
range
(
len
(
labels
)):
label
,
pred
=
labels
[
i
],
preds
[
i
]
if
save_source
:
results
.
append
(
dict
(
pred
=
pred
,
label
=
label
,
source
=
docs
[
i
]))
else
:
results
.
append
(
dict
(
pred
=
pred
,
label
=
label
))
results
.
append
(
dict
(
pred
=
pred
,
label
=
label
,
id
=
ids
[
i
].
item
()))
save_json
(
results
,
save_path
)
return
results
return
results
,
sampler
.
num_replicas
def
run_generate
():
parser
=
argparse
.
ArgumentParser
(
epilog
=
"Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
)
parser
.
add_argument
(
"--
input_path
"
,
type
=
str
,
help
=
"like cnn_dm/test.source"
)
parser
.
add_argument
(
"--
data_dir
"
,
type
=
str
,
help
=
"like cnn_dm/test.source"
)
parser
.
add_argument
(
"--model_name"
,
type
=
str
,
...
...
@@ -113,17 +128,31 @@ def run_generate():
parser
.
add_argument
(
"--n_obs"
,
type
=
int
,
default
=
None
,
required
=
False
,
help
=
"How many observations. Defaults to all."
)
parser
.
add_argument
(
"--sync_timeout"
,
type
=
int
,
default
=
600
,
required
=
False
,
help
=
"How long should master process wait for other processes to finish."
,
)
parser
.
add_argument
(
"--fp16"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--
save_source
"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--
debug
"
,
action
=
"store_true"
)
start_time
=
time
.
time
()
args
,
rest
=
parser
.
parse_known_args
()
generate_kwargs
=
parse_numeric_cl_kwargs
(
rest
)
if
generate_kwargs
:
print
(
f
"parsed the following generate kwargs:
{
generate_kwargs
}
"
)
json_save_dir
=
Path
(
args
.
save_dir
+
"_tmp"
)
Path
(
json_save_dir
).
mkdir
(
exist_ok
=
True
)
# this handles locking.
intermediate_files
=
list
(
json_save_dir
.
glob
(
"rank_*.json"
))
if
intermediate_files
:
raise
ValueError
(
f
"Found files at
{
json_save_dir
}
please move or remove them."
)
# In theory, a node could finish and save before another node hits this. If this happens, we can address later.
Path
(
args
.
save_dir
).
mkdir
(
exist_ok
=
True
)
eval_data_dir
(
args
.
input_path
,
args
.
save_dir
,
results
,
num_replicas
=
eval_data_dir
(
args
.
data_dir
,
json_
save_dir
,
args
.
model_name
,
type_path
=
args
.
type_path
,
batch_size
=
args
.
bs
,
...
...
@@ -131,11 +160,64 @@ def run_generate():
task
=
args
.
task
,
local_rank
=
args
.
local_rank
,
n_obs
=
args
.
n_obs
,
save_source
=
args
.
save_source
,
max_source_length
=
args
.
max_source_length
,
**
generate_kwargs
,
)
if
args
.
local_rank
<=
0
:
save_dir
=
Path
(
args
.
save_dir
)
save_dir
.
mkdir
(
exist_ok
=
True
)
partial_results
=
gather_results_from_each_node
(
num_replicas
,
json_save_dir
,
args
.
sync_timeout
)
preds
,
labels
=
combine_partial_results
(
partial_results
)
# Calculate metrics, save metrics, and save _generations.txt
calc_bleu
=
"translation"
in
args
.
task
score_fn
=
calculate_bleu
if
calc_bleu
else
calculate_rouge
metric_name
=
"bleu"
if
calc_bleu
else
"rouge"
metrics
:
Dict
=
score_fn
(
preds
,
labels
)
metrics
[
"n_obs"
]
=
len
(
preds
)
runtime
=
time
.
time
()
-
start_time
metrics
[
"seconds_per_sample"
]
=
round
(
runtime
/
metrics
[
"n_obs"
],
2
)
# TODO(@stas00): add whatever metadata to metrics
metrics_save_path
=
save_dir
.
joinpath
(
f
"
{
args
.
type_path
}
_
{
metric_name
}
.json"
)
save_json
(
metrics
,
metrics_save_path
)
print
(
metrics
)
write_txt_file
(
preds
,
save_dir
.
joinpath
(
f
"
{
args
.
type_path
}
_generations.txt"
))
if
args
.
debug
:
write_txt_file
(
labels
,
save_dir
.
joinpath
(
f
"
{
args
.
type_path
}
.target"
))
else
:
shutil
.
rmtree
(
json_save_dir
)
def combine_partial_results(partial_results) -> Tuple[List, List]:
    """Flatten the per-rank result lists, order records by "id", and split out predictions and labels."""
    merged = []
    for rank_records in partial_results:
        merged.extend(rank_records)
    # Ranks finish in arbitrary order; restore the original dataset order.
    merged.sort(key=lambda rec: rec["id"])
    preds = [rec["pred"] for rec in merged]
    labels = [rec["label"] for rec in merged]
    return preds, labels
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
    """Wait for every rank to write its ``rank_*.json`` file, then load them all.

    Args:
        num_replicas: number of rank files expected.
        save_dir: ``Path`` the ranks write their partial results into.
        timeout: seconds the master process waits before giving up.

    Returns:
        One loaded-JSON payload per rank file.

    Raises:
        TimeoutError: fewer than ``num_replicas`` readable files appeared in time.
    """
    # WAIT FOR lots of .json files
    start_wait = time.time()
    logger.info("waiting for all nodes to finish")
    while (time.time() - start_wait) < timeout:
        json_files = list(save_dir.glob("rank_*.json"))
        if len(json_files) < num_replicas:
            # Fix: the original spun in a tight loop; back off briefly instead
            # of hammering the filesystem.
            time.sleep(1)
            continue
        try:
            # make sure all json files are fully saved (a partially written
            # file raises JSONDecodeError, so we retry)
            return lmap(load_json, json_files)
        except JSONDecodeError:
            continue
    else:
        raise TimeoutError("Rank 0 gave up on waiting for other processes")
    # Unreachable
if
__name__
==
"__main__"
:
# Usage for MT:
...
...
examples/seq2seq/utils.py
View file @
33d479d2
...
...
@@ -18,6 +18,7 @@ from torch import nn
from
torch.utils.data
import
Dataset
,
Sampler
from
transformers
import
BartTokenizer
from
transformers.file_utils
import
cached_property
def
label_smoothed_nll_loss
(
lprobs
,
target
,
epsilon
,
ignore_index
=-
100
):
...
...
@@ -114,9 +115,9 @@ class AbstractSeq2SeqDataset(Dataset):
def
get_char_lens
(
data_file
):
return
[
len
(
x
)
for
x
in
Path
(
data_file
).
open
().
readlines
()]
def
make_sortish_sampler
(
self
,
batch_size
,
distributed
=
False
):
def
make_sortish_sampler
(
self
,
batch_size
,
distributed
=
False
,
**
kwargs
):
if
distributed
:
return
DistributedSortishSampler
(
self
,
batch_size
)
return
DistributedSortishSampler
(
self
,
batch_size
,
**
kwargs
)
else
:
return
SortishSampler
(
self
.
src_lens
,
batch_size
)
...
...
@@ -171,14 +172,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
tgt_line
=
linecache
.
getline
(
str
(
self
.
tgt_file
),
index
).
rstrip
(
"
\n
"
)
assert
source_line
,
f
"empty source line for index
{
index
}
"
assert
tgt_line
,
f
"empty tgt line for index
{
index
}
"
return
{
"tgt_texts"
:
tgt_line
,
"src_texts"
:
source_line
,
}
return
{
"tgt_texts"
:
tgt_line
,
"src_texts"
:
source_line
,
"id"
:
index
-
1
}
def
collate_fn
(
self
,
batch
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""Call prepare_seq2seq_batch."""
batch_encoding
=
self
.
tokenizer
.
prepare_seq2seq_batch
(
batch_encoding
:
Dict
[
str
,
torch
.
Tensor
]
=
self
.
tokenizer
.
prepare_seq2seq_batch
(
[
x
[
"src_texts"
]
for
x
in
batch
],
src_lang
=
self
.
src_lang
,
tgt_texts
=
[
x
[
"tgt_texts"
]
for
x
in
batch
],
...
...
@@ -187,8 +185,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
max_target_length
=
self
.
max_target_length
,
return_tensors
=
"pt"
,
add_prefix_space
=
self
.
add_prefix_space
,
)
return
batch_encoding
.
data
).
data
batch_encoding
[
"ids"
]
=
torch
.
tensor
([
x
[
"id"
]
for
x
in
batch
])
return
batch_encoding
class
SortishSampler
(
Sampler
):
...
...
@@ -226,7 +225,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
class
DistributedSortishSampler
(
Sampler
):
"""Copied from torch DistributedSampler"""
def
__init__
(
self
,
dataset
,
batch_size
,
num_replicas
=
None
,
rank
=
None
):
def
__init__
(
self
,
dataset
,
batch_size
,
num_replicas
=
None
,
rank
=
None
,
add_extra_examples
=
True
):
if
num_replicas
is
None
:
if
not
dist
.
is_available
():
raise
RuntimeError
(
"Requires distributed package to be available"
)
...
...
@@ -239,22 +238,27 @@ class DistributedSortishSampler(Sampler):
self
.
num_replicas
=
num_replicas
self
.
rank
=
rank
self
.
epoch
=
0
if
add_extra_examples
:
self
.
num_samples
=
int
(
math
.
ceil
(
len
(
self
.
dataset
)
*
1.0
/
self
.
num_replicas
))
self
.
total_size
=
self
.
num_samples
*
self
.
num_replicas
else
:
self
.
total_size
=
len
(
dataset
)
self
.
num_samples
=
len
(
self
.
available_indices
)
self
.
batch_size
=
batch_size
self
.
add_extra_examples
=
add_extra_examples
def
__iter__
(
self
)
->
Iterable
:
g
=
torch
.
Generator
()
g
.
manual_seed
(
self
.
epoch
)
available_indices
=
self
.
get_indices_for_rank
()
# indices[self.rank: self.total_size: self.num_replicas]
sortish_data
=
[
self
.
dataset
.
src_lens
[
i
]
for
i
in
available_indices
]
sortish_data
=
[
self
.
dataset
.
src_lens
[
i
]
for
i
in
self
.
available_indices
]
sortish_indices
=
sortish_sampler_indices
(
sortish_data
,
self
.
batch_size
)
indices
=
[
available_indices
[
i
]
for
i
in
sortish_indices
]
indices
=
[
self
.
available_indices
[
i
]
for
i
in
sortish_indices
]
assert
len
(
indices
)
==
self
.
num_samples
return
iter
(
indices
)
def
get_indices_for_rank
(
self
)
->
np
.
array
:
@
cached_property
def
available_indices
(
self
)
->
np
.
array
:
indices
=
list
(
range
(
len
(
self
.
dataset
)))
# add extra samples to make it evenly divisible
indices
+=
indices
[:
(
self
.
total_size
-
len
(
indices
))]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment