Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e7f8d2ab
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ecfddc6034a12c3ef6c4ef3e6f56f7d034ec3075"
Unverified
Commit
e7f8d2ab
authored
Sep 13, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 13, 2020
Browse files
[s2s] two stage run_distributed_eval.py (#7105)
parent
0ec63afe
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
192 additions
and
0 deletions
+192
-0
examples/seq2seq/aggregate_distributed_results.py
examples/seq2seq/aggregate_distributed_results.py
+46
-0
examples/seq2seq/run_distributed_eval.py
examples/seq2seq/run_distributed_eval.py
+139
-0
examples/seq2seq/utils.py
examples/seq2seq/utils.py
+7
-0
No files found.
examples/seq2seq/aggregate_distributed_results.py
0 → 100644
View file @
e7f8d2ab
from
pathlib
import
Path
import
fire
try
:
from
.utils
import
calculate_bleu
,
calculate_rouge
,
load_json
,
save_json
,
write_txt_file
except
ImportError
:
from
utils
import
calculate_bleu
,
calculate_rouge
,
load_json
,
save_json
,
write_txt_file
def combine_partial_results(
    result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False
):
    """Aggregate the per-rank ``rank*.json`` files written by run_distributed_eval.py.

    Loads every ``rank*.json`` in ``result_dir``, scores the combined
    predictions against the labels with ROUGE (or BLEU), writes
    ``metrics.json`` to ``save_dir``, and — unless ``just_metrics`` —
    writes ``{save_prefix}.target`` / ``{save_prefix}.pred_target``
    (plus ``{save_prefix}.source`` when the records carry a source) text files.

    Args:
        result_dir: directory containing the partial ``rank*.json`` results.
        save_dir: output directory; defaults to ``result_dir`` when None.
        save_prefix: filename prefix for the text files; defaults to "generated".
        calc_bleu: score with BLEU instead of ROUGE.
        just_metrics: stop after writing ``metrics.json``.

    Raises:
        FileNotFoundError: if no ``rank*.json`` files exist in ``result_dir``.
    """
    src_dir = Path(result_dir)
    if save_dir is None:
        # Previously Path(None) raised TypeError; fall back to the input dir.
        save_dir = result_dir
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)
    paths_to_combine = list(src_dir.glob("rank*.json"))
    if not paths_to_combine:
        # Without this guard an empty directory produced empty metrics and
        # an IndexError at records[0] below.
        raise FileNotFoundError(f"no rank*.json files found in {result_dir}")
    records = []
    for partial_result in paths_to_combine:
        records.extend(load_json(partial_result))
    preds = [x["pred"] for x in records]
    labels = [x["label"] for x in records]
    score_fn = calculate_bleu if calc_bleu else calculate_rouge
    metrics = score_fn(preds, labels)
    # better would be {save_prefix}_{rouge|bleu}.json
    save_json(metrics, save_dir.joinpath("metrics.json"))
    print(metrics)
    if just_metrics:
        return
    if save_prefix is None:
        save_prefix = "generated"
        print("using generated as prefix")
    tgt_path = save_dir.joinpath(f"{save_prefix}.target")
    write_txt_file(labels, tgt_path)
    pred_path = save_dir.joinpath(f"{save_prefix}.pred_target")
    write_txt_file(preds, pred_path)
    if "source" in records[0]:
        src_path = save_dir.joinpath(f"{save_prefix}.source")
        write_txt_file([x["source"] for x in records], src_path)
if __name__ == "__main__":
    # CLI entry point: python-fire maps combine_partial_results' arguments
    # to command-line flags, e.g. --save_dir / --calc_bleu.
    fire.Fire(combine_partial_results)
examples/seq2seq/run_distributed_eval.py
0 → 100644
View file @
e7f8d2ab
import
argparse
import
warnings
from
logging
import
getLogger
from
pathlib
import
Path
from
typing
import
Dict
import
torch
from
torch.utils.data
import
DataLoader
from
tqdm
import
tqdm
from
transformers
import
AutoModelForSeq2SeqLM
,
AutoTokenizer
logger
=
getLogger
(
__name__
)
try
:
from
.utils
import
Seq2SeqDataset
,
parse_numeric_cl_kwargs
,
save_json
,
use_task_specific_params
except
ImportError
:
from
utils
import
Seq2SeqDataset
,
parse_numeric_cl_kwargs
,
save_json
,
use_task_specific_params
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
def eval_data_dir(
    data_dir,
    save_dir: str,
    model_name: str,
    bs: int = 8,
    max_source_length: int = 1024,
    type_path="val",
    n_obs=None,
    fp16=False,
    save_source=False,
    num_beams: int = 4,
    task="summarization",
    local_rank=None,
    **generate_kwargs,
) -> Dict:
    """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json.

    Joins the NCCL process group for ``local_rank``, generates predictions for
    this rank's shard of the ``type_path`` split of ``data_dir``, and saves a
    list of ``{"pred": ..., "label": ...[, "source": ...]}`` records.

    Args:
        data_dir: seq2seq data directory (expects {type_path}.source/.target files).
        save_dir: directory for the per-rank output json.
        model_name: model id or path for AutoModelForSeq2SeqLM/AutoTokenizer.
        bs: per-device batch size.
        max_source_length: max input tokens (targets are truncated at 1024).
        type_path: which split to evaluate, e.g. "train"/"val"/"test".
        n_obs: evaluate only this many examples (None = all).
        fp16: run the model in half precision.
        save_source: also store the decoded source document per record.
        num_beams: beam size for generation.
        task: used to apply task-specific generation params from the config.
        local_rank: this process's rank; must be passed by distributed.launch.
        **generate_kwargs: forwarded verbatim to ``model.generate``.

    Returns:
        The list of result records written to disk.

    Raises:
        ValueError: if ``local_rank`` is None.
    """
    model_name = str(model_name)
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if local_rank is None:
        raise ValueError("local_rank must be set (pass via torch.distributed.launch)")
    torch.distributed.init_process_group(backend="nccl", rank=local_rank)
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)  # don't rely on the caller having created it
    save_path = save_dir.joinpath(f"rank_{local_rank}_output.json")
    torch.cuda.set_device(local_rank)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
    if fp16:
        model = model.half()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.
    use_task_specific_params(model, task)  # update config with task specific params
    ds = Seq2SeqDataset(
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length=1024,
        type_path=type_path,
        n_obs=n_obs,
        prefix=model.config.prefix,
    )
    # Sortish sampler groups similarly-sized examples and shards across ranks.
    sampler = ds.make_sortish_sampler(bs, distributed=True)
    data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
    dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False)  # tokenizer.decode
    results = []
    for batch in tqdm(data_loader):
        summaries = model.generate(
            input_ids=batch["input_ids"].to(model.device),
            attention_mask=batch["attention_mask"].to(model.device),
            num_beams=num_beams,
            **generate_kwargs,
        )
        preds = tokenizer.batch_decode(summaries, **dec_kwargs)
        labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
        if save_source:
            docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs)
        for i, (pred, label) in enumerate(zip(preds, labels)):
            record = dict(pred=pred, label=label)
            if save_source:
                record["source"] = docs[i]
            results.append(record)
    save_json(results, save_path)
    return results
def run_generate():
    """Parse CLI args and run this rank's shard of distributed evaluation.

    Known flags are parsed normally; any leftover flags (e.g.
    ``--decoder_start_token_id=4``) are parsed numerically and forwarded to
    ``model.generate`` via ``eval_data_dir``'s ``**generate_kwargs``.
    """
    parser = argparse.ArgumentParser(
        epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
    )
    parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source")
    parser.add_argument(
        "--model_name",
        type=str,
        help="like facebook/bart-large-cnn,t5-base, etc.",
        default="sshleifer/distilbart-xsum-12-3",
    )
    parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
    parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test")
    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
    parser.add_argument(
        "--score_path", type=str, required=False, default="metrics.json", help="where to save metrics"
    )
    parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
    parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
    parser.add_argument(
        "--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch"
    )
    parser.add_argument(
        "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
    )
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--save_source", action="store_true")
    args, rest = parser.parse_known_args()
    parsed = parse_numeric_cl_kwargs(rest)
    if parsed:
        print(f"parsed the following generate kwargs: {parsed}")
    Path(args.save_dir).mkdir(exist_ok=True)
    if args.reference_path is None and Path(args.score_path).exists():
        warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
    eval_data_dir(
        args.input_path,
        args.save_dir,
        args.model_name,
        # BUGFIX: these previously used the wrong keyword names (prefix= and
        # batch_size=), so they fell into **generate_kwargs and were forwarded
        # to model.generate instead — --prefix and --bs were silently ignored.
        type_path=args.prefix,
        bs=args.bs,
        fp16=args.fp16,
        task=args.task,
        local_rank=args.local_rank,
        n_obs=args.n_obs,
        save_source=args.save_source,
        **parsed,
    )
if __name__ == "__main__":
    # NOTE(review): intended to be launched one process per GPU, presumably via
    # `python -m torch.distributed.launch --nproc_per_node=N run_distributed_eval.py ...`
    # (--local_rank is documented as "should be passed by distributed.launch");
    # each rank writes save_dir/rank_{rank}_output.json.
    run_generate()
examples/seq2seq/utils.py
View file @
e7f8d2ab
...
@@ -387,3 +387,10 @@ def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, fl
...
@@ -387,3 +387,10 @@ def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, fl
result
[
unparsed_args
[
i
][
2
:]]
=
value
result
[
unparsed_args
[
i
][
2
:]]
=
value
return
result
return
result
def write_txt_file(ordered_tgt, path):
    """Write each element of ``ordered_tgt`` to ``path``, one per line.

    Args:
        ordered_tgt: iterable of strings (written in iteration order).
        path: destination file path (str or Path); overwritten if it exists.
    """
    # Use a context manager so the handle is always closed (the original
    # opened the file, flushed, and leaked the descriptor).
    with Path(path).open("w") as f:
        for ln in ordered_tgt:
            f.write(ln + "\n")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment