chenpangpang/transformers, commit 33d479d2 (unverified)
Authored Sep 14, 2020 by Sam Shleifer; committed by GitHub on Sep 14, 2020

[s2s] distributed eval in one command (#7124)

parent 206b78d4

Showing 4 changed files with 125 additions and 85 deletions:

- examples/seq2seq/aggregate_distributed_results.py (+0 / -46)
- examples/seq2seq/romanian_postprocessing.md (+1 / -1)
- examples/seq2seq/run_distributed_eval.py (+104 / -22)
- examples/seq2seq/utils.py (+20 / -16)
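The commit folds the old launch-then-aggregate workflow (including the deleted aggregate_distributed_results.py below) into a single torch.distributed.launch invocation of run_distributed_eval.py. A minimal sketch of what that launch might look like, wrapped in Python: the flag names follow the argparse calls visible in this diff, while the checkpoint name, data directory, save directory, batch size, and process count are illustrative placeholders, not values taken from the commit.

```python
# A hypothetical one-command launch of the new run_distributed_eval.py.
# Flag names mirror the argparse definitions in this diff; the model
# checkpoint, data dir, save dir, and batch size below are made up.
import subprocess

subprocess.run(
    [
        "python", "-m", "torch.distributed.launch", "--nproc_per_node=2",
        "run_distributed_eval.py",
        "--model_name", "sshleifer/distilbart-xsum-12-3",  # placeholder checkpoint
        "--data_dir", "xsum",                              # placeholder dataset dir
        "--save_dir", "xsum_generations",                  # metrics + generations land here
        "--bs", "16",
        "--fp16",
    ],
    check=True,
)
```

torch.distributed.launch injects --local_rank into every worker, which is why run_generate() below gates the gather-and-score step on args.local_rank <= 0, so that only one process aggregates the rank shards.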
examples/seq2seq/aggregate_distributed_results.py (deleted, 100644 → 0; view file @ 206b78d4)
```python
from pathlib import Path

import fire

try:
    from .utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
except ImportError:
    from utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file


def combine_partial_results(
    result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False
):
    """Write first n lines of each file f in src_dir to dest_dir/f """
    src_dir = Path(result_dir)
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)
    paths_to_combine = list(src_dir.glob("rank*.json"))
    records = []
    for partial_result in paths_to_combine:
        records.extend(load_json(partial_result))
    preds = [x["pred"] for x in records]
    labels = [x["label"] for x in records]
    score_fn = calculate_bleu if calc_bleu else calculate_rouge
    metrics = score_fn(preds, labels)
    save_json(metrics, save_dir.joinpath("metrics.json"))  # better would be {prefix}_{rouge|bleu}.json
    print(metrics)
    if just_metrics:
        return
    if save_prefix is None:
        save_prefix = "generated"
        print("using generated as prefix")
    tgt_path = save_dir.joinpath(f"{save_prefix}.target")
    write_txt_file(labels, tgt_path)
    pred_path = save_dir.joinpath(f"{save_prefix}.pred_target")
    write_txt_file(preds, pred_path)
    if "source" in records[0]:
        src_path = save_dir.joinpath(f"{save_prefix}.source")
        write_txt_file([x["source"] for x in records], src_path)


if __name__ == "__main__":
    fire.Fire(combine_partial_results)
```
examples/seq2seq/romanian_postprocessing.md (view file @ 33d479d2)

```diff
@@ -12,7 +12,7 @@ Note: You need to have your test_generations.txt before you start this process.
 cd $HOME
 git clone git@github.com:moses-smt/mosesdecoder.git
 cd mosesdecoder
-git@github.com:rsennrich/wmt16-scripts.git
+git clone git@github.com:rsennrich/wmt16-scripts.git
 ```
 (2) define a function for post processing.
 ...
```
examples/seq2seq/run_distributed_eval.py (view file @ 33d479d2)

```diff
 import argparse
+import shutil
+import time
+from json import JSONDecodeError
 from logging import getLogger
 from pathlib import Path
-from typing import Dict
+from typing import Dict, List, Tuple
 
 import torch
 from torch.utils.data import DataLoader
@@ -13,12 +16,29 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 logger = getLogger(__name__)
 
 try:
-    from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
+    from .utils import (
+        Seq2SeqDataset,
+        calculate_bleu,
+        calculate_rouge,
+        lmap,
+        load_json,
+        parse_numeric_cl_kwargs,
+        save_json,
+        use_task_specific_params,
+        write_txt_file,
+    )
 except ImportError:
-    from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
+    from utils import (
+        Seq2SeqDataset,
+        calculate_bleu,
+        calculate_rouge,
+        lmap,
+        load_json,
+        parse_numeric_cl_kwargs,
+        save_json,
+        use_task_specific_params,
+        write_txt_file,
+    )
 
 DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 def eval_data_dir(
@@ -30,7 +50,6 @@ def eval_data_dir(
     type_path="val",
     n_obs=None,
     fp16=False,
-    save_source=False,
     num_beams: int = 4,
     task="summarization",
     local_rank=None,
@@ -62,7 +81,7 @@ def eval_data_dir(
         n_obs=n_obs,
         prefix=model.config.prefix,
     )
-    sampler = ds.make_sortish_sampler(bs, distributed=True)
+    sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False)
     data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
     dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False)  # tokenizer.decode
     results = []
@@ -75,23 +94,19 @@ def eval_data_dir(
         )
         preds = tokenizer.batch_decode(summaries, **dec_kwargs)
         labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
-        if save_source:
-            docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs)
+        ids = batch["ids"]
         for i in range(len(labels)):
             label, pred = labels[i], preds[i]
-            if save_source:
-                results.append(dict(pred=pred, label=label, source=docs[i]))
-            else:
-                results.append(dict(pred=pred, label=label))
+            results.append(dict(pred=pred, label=label, id=ids[i].item()))
         save_json(results, save_path)
-    return results
+    return results, sampler.num_replicas
 
 
 def run_generate():
     parser = argparse.ArgumentParser(
         epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
     )
-    parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source")
+    parser.add_argument("--data_dir", type=str, help="like cnn_dm/test.source")
     parser.add_argument(
         "--model_name",
         type=str,
@@ -113,17 +128,31 @@ def run_generate():
     parser.add_argument(
         "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
     )
+    parser.add_argument(
+        "--sync_timeout",
+        type=int,
+        default=600,
+        required=False,
+        help="How long should master process wait for other processes to finish.",
+    )
     parser.add_argument("--fp16", action="store_true")
-    parser.add_argument("--save_source", action="store_true")
+    parser.add_argument("--debug", action="store_true")
+    start_time = time.time()
     args, rest = parser.parse_known_args()
    generate_kwargs = parse_numeric_cl_kwargs(rest)
     if generate_kwargs:
         print(f"parsed the following generate kwargs: {generate_kwargs}")
+    json_save_dir = Path(args.save_dir + "_tmp")
+    Path(json_save_dir).mkdir(exist_ok=True)  # this handles locking.
+    intermediate_files = list(json_save_dir.glob("rank_*.json"))
+    if intermediate_files:
+        raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
+    # In theory, a node could finish and save before another node hits this. If this happens, we can address later.
     Path(args.save_dir).mkdir(exist_ok=True)
-    eval_data_dir(
-        args.input_path,
-        args.save_dir,
+    results, num_replicas = eval_data_dir(
+        args.data_dir,
+        json_save_dir,
         args.model_name,
         type_path=args.type_path,
         batch_size=args.bs,
@@ -131,11 +160,64 @@ def run_generate():
         task=args.task,
         local_rank=args.local_rank,
         n_obs=args.n_obs,
-        save_source=args.save_source,
         max_source_length=args.max_source_length,
         **generate_kwargs,
     )
+    if args.local_rank <= 0:
+        save_dir = Path(args.save_dir)
+        save_dir.mkdir(exist_ok=True)
+        partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
+        preds, labels = combine_partial_results(partial_results)
+        # Calculate metrics, save metrics, and save _generations.txt
+        calc_bleu = "translation" in args.task
+        score_fn = calculate_bleu if calc_bleu else calculate_rouge
+        metric_name = "bleu" if calc_bleu else "rouge"
+        metrics: Dict = score_fn(preds, labels)
+        metrics["n_obs"] = len(preds)
+        runtime = time.time() - start_time
+        metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2)
+        # TODO(@stas00): add whatever metadata to metrics
+        metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
+        save_json(metrics, metrics_save_path)
+        print(metrics)
+        write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
+        if args.debug:
+            write_txt_file(labels, save_dir.joinpath(f"{args.type_path}.target"))
+        else:
+            shutil.rmtree(json_save_dir)
+
+
+def combine_partial_results(partial_results) -> Tuple[List, List]:
+    """Concatenate partial results into one file, then sort it by id."""
+    records = []
+    for partial_result in partial_results:
+        records.extend(partial_result)
+    records = list(sorted(records, key=lambda x: x["id"]))
+    preds = [x["pred"] for x in records]
+    labels = [x["label"] for x in records]
+    return preds, labels
+
+
+def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
+    # WAIT FOR lots of .json files
+    start_wait = time.time()
+    logger.info("waiting for all nodes to finish")
+    json_data = None
+    while (time.time() - start_wait) < timeout:
+        json_files = list(save_dir.glob("rank_*.json"))
+        if len(json_files) < num_replicas:
+            continue
+        try:
+            # make sure all json files are fully saved
+            json_data = lmap(load_json, json_files)
+            return json_data
+        except JSONDecodeError:
+            continue
+    else:
+        raise TimeoutError("Rank 0 gave up on waiting for other processes")
+    # Unreachable
 
 
 if __name__ == "__main__":
     # Usage for MT:
     ...
```
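For intuition about what rank 0 does once every worker has written its shard, here is a self-contained sketch of the gather-and-combine step: each rank saves a rank_<i>.json list of {pred, label, id} records, and the driver concatenates the shards and re-sorts by id to restore dataset order. The file names and record keys mirror the code above; the toy data and the demo_tmp directory are invented.

```python
# Toy demonstration of the rank_*.json merge performed by
# gather_results_from_each_node + combine_partial_results above.
import json
from pathlib import Path

tmp = Path("demo_tmp")
tmp.mkdir(exist_ok=True)
# Pretend two ranks each saved an out-of-order shard of predictions.
tmp.joinpath("rank_0.json").write_text(
    json.dumps([{"id": 2, "pred": "c", "label": "C"}, {"id": 0, "pred": "a", "label": "A"}])
)
tmp.joinpath("rank_1.json").write_text(json.dumps([{"id": 1, "pred": "b", "label": "B"}]))

# Same merge logic as combine_partial_results: concatenate, then sort by id.
records = []
for shard in sorted(tmp.glob("rank_*.json")):
    records.extend(json.loads(shard.read_text()))
records = sorted(records, key=lambda x: x["id"])
preds = [x["pred"] for x in records]
labels = [x["label"] for x in records]
print(preds, labels)  # ['a', 'b', 'c'] ['A', 'B', 'C']
```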
examples/seq2seq/utils.py (view file @ 33d479d2)

```diff
@@ -18,6 +18,7 @@ from torch import nn
 from torch.utils.data import Dataset, Sampler
 
 from transformers import BartTokenizer
+from transformers.file_utils import cached_property
 
 
 def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
@@ -114,9 +115,9 @@ class AbstractSeq2SeqDataset(Dataset):
     def get_char_lens(data_file):
         return [len(x) for x in Path(data_file).open().readlines()]
 
-    def make_sortish_sampler(self, batch_size, distributed=False):
+    def make_sortish_sampler(self, batch_size, distributed=False, **kwargs):
         if distributed:
-            return DistributedSortishSampler(self, batch_size)
+            return DistributedSortishSampler(self, batch_size, **kwargs)
         else:
             return SortishSampler(self.src_lens, batch_size)
@@ -171,14 +172,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
         tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
         assert source_line, f"empty source line for index {index}"
         assert tgt_line, f"empty tgt line for index {index}"
-        return {
-            "tgt_texts": tgt_line,
-            "src_texts": source_line,
-        }
+        return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
 
     def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
         """Call prepare_seq2seq_batch."""
-        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
+        batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
             [x["src_texts"] for x in batch],
             src_lang=self.src_lang,
             tgt_texts=[x["tgt_texts"] for x in batch],
@@ -187,8 +185,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
             max_target_length=self.max_target_length,
             return_tensors="pt",
             add_prefix_space=self.add_prefix_space,
-        )
-        return batch_encoding.data
+        ).data
+        batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
+        return batch_encoding
 
 
 class SortishSampler(Sampler):
@@ -226,7 +225,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
 class DistributedSortishSampler(Sampler):
     """Copied from torch DistributedSampler"""
 
-    def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
+    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -239,22 +238,27 @@ class DistributedSortishSampler(Sampler):
         self.num_replicas = num_replicas
         self.rank = rank
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
-        self.total_size = self.num_samples * self.num_replicas
+        if add_extra_examples:
+            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+            self.total_size = self.num_samples * self.num_replicas
+        else:
+            self.total_size = len(dataset)
+            self.num_samples = len(self.available_indices)
         self.batch_size = batch_size
+        self.add_extra_examples = add_extra_examples
 
     def __iter__(self) -> Iterable:
         g = torch.Generator()
         g.manual_seed(self.epoch)
-        available_indices = self.get_indices_for_rank()  # indices[self.rank: self.total_size: self.num_replicas]
-        sortish_data = [self.dataset.src_lens[i] for i in available_indices]
+        sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
         sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
-        indices = [available_indices[i] for i in sortish_indices]
+        indices = [self.available_indices[i] for i in sortish_indices]
         assert len(indices) == self.num_samples
         return iter(indices)
 
-    def get_indices_for_rank(self) -> np.array:
+    @cached_property
+    def available_indices(self) -> np.array:
         indices = list(range(len(self.dataset)))
         # add extra samples to make it evenly divisible
         indices += indices[: (self.total_size - len(indices))]
 ...
```
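The add_extra_examples switch is the subtle half of this change. The stock DistributedSampler recipe (still the default here) pads the index list so every rank draws the same number of samples, which duplicates a few examples, harmless for training but metric-skewing for evaluation. A toy sketch of the difference, assuming the rank-strided slicing noted in the comment above (indices[self.rank : self.total_size : self.num_replicas]):

```python
# Partition 10 examples across 3 ranks, with and without the padding
# that DistributedSortishSampler applies when add_extra_examples=True.
import math

def rank_indices(n, num_replicas, rank, add_extra_examples):
    indices = list(range(n))
    if add_extra_examples:
        num_samples = int(math.ceil(n * 1.0 / num_replicas))
        total_size = num_samples * num_replicas
        indices += indices[: total_size - len(indices)]  # wrap around to pad
    else:
        total_size = n
    return indices[rank:total_size:num_replicas]

for rank in range(3):
    print(rank, rank_indices(10, 3, rank, True), rank_indices(10, 3, rank, False))
# 0 [0, 3, 6, 9] [0, 3, 6, 9]
# 1 [1, 4, 7, 0] [1, 4, 7]
# 2 [2, 5, 8, 1] [2, 5, 8]
```

With the padded duplicates gone, the combined rank shards from run_distributed_eval.py contain each example exactly once, so the corpus-level ROUGE/BLEU numbers are computed over the true dataset.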