import argparse
import shutil
import time
from json import JSONDecodeError
from logging import getLogger
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils import (
    Seq2SeqDataset,
    calculate_bleu,
    calculate_rouge,
    lmap,
    load_json,
    parse_numeric_n_bool_cl_kwargs,
    save_json,
    use_task_specific_params,
    write_txt_file,
)


logger = getLogger(__name__)


def eval_data_dir(
    data_dir,
    save_dir: str,
    model_name: str,
    bs: int = 8,
    max_source_length: int = 1024,
    type_path="val",
    n_obs=None,
    fp16=False,
    task="summarization",
    local_rank=None,
    src_lang=None,
    tgt_lang=None,
    prefix="",
    **generate_kwargs,
) -> Tuple[List[Dict], int]:
    """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
    model_name = str(model_name)
    assert local_rank is not None
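    # The default init_method is env://, so the world size and master address come
    # from the environment variables set by torch.distributed.launch.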
    torch.distributed.init_process_group(backend="nccl", rank=local_rank)

    save_dir = Path(save_dir)
    save_path = save_dir.joinpath(f"rank_{local_rank}_output.json")
    torch.cuda.set_device(local_rank)
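    # set_device ensures the .cuda() call below places the model on this rank's GPU.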
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
    if fp16:
        model = model.half()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.
    use_task_specific_params(model, task)  # update config with task specific params
    if max_source_length is None:
        max_source_length = tokenizer.model_max_length
    if prefix is None:
        prefix = getattr(model.config, "prefix", "") or ""
    ds = Seq2SeqDataset(
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length=1024,
        type_path=type_path,
        n_obs=n_obs,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        prefix=prefix,
    )
    # I set shuffle=True for a more accurate progress bar.
    # If all the longest samples are first, the prog bar estimate is too high at the beginning.
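    # distributed=True shards the data so each rank generates predictions for a
    # disjoint slice of the dataset.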
    sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True)
    data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
    results = []
    for batch in tqdm(data_loader):
        summaries = model.generate(
            input_ids=batch["input_ids"].to(model.device),
            attention_mask=batch["attention_mask"].to(model.device),
            **generate_kwargs,
        )
        preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        ids = batch["ids"]
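        # Keep the original example ids so the per-rank outputs can be merged
        # back into dataset order by combine_partial_results.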
        for i, pred in enumerate(preds):
            results.append(dict(pred=pred, id=ids[i].item()))
    save_json(results, save_path)
    return results, sampler.num_replicas


def run_generate():
    parser = argparse.ArgumentParser(
        epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
    )
    parser.add_argument("--data_dir", type=str, help="like cnn_dm/test.source")
    parser.add_argument(
        "--model_name",
        type=str,
        help="like facebook/bart-large-cnn,t5-base, etc.",
        default="sshleifer/distilbart-xsum-12-3",
    )
    parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
    parser.add_argument("--max_source_length", type=int, default=None)
    parser.add_argument(
        "--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test"
    )
    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
    parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
    parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
    parser.add_argument(
        "--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch"
    )

    parser.add_argument(
        "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
    )
    parser.add_argument(
        "--sync_timeout",
        type=int,
        default=600,
        required=False,
        help="How long should master process wait for other processes to finish.",
    )
    parser.add_argument("--src_lang", type=str, default=None, required=False)
    parser.add_argument("--tgt_lang", type=str, default=None, required=False)
    parser.add_argument(
        "--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples"
    )
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--debug", action="store_true")
    start_time = time.time()
    args, rest = parser.parse_known_args()
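    # Unrecognized --key=value args (e.g. --num_beams=2) are parsed into
    # numeric/bool kwargs and forwarded to model.generate.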
    generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest)
    if generate_kwargs and args.local_rank <= 0:
        print(f"parsed the following generate kwargs: {generate_kwargs}")
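    # Per-rank predictions are written to a temporary sibling directory, which is
    # removed at the end unless --debug is passed.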
    json_save_dir = Path(args.save_dir + "_tmp")
    json_save_dir.mkdir(exist_ok=True)  # exist_ok=True avoids a race when multiple ranks create the dir.
    intermediate_files = list(json_save_dir.glob("rank_*.json"))
    if intermediate_files:
        # In theory, one rank could finish and save before another rank hits this check;
        # if that happens, it can be addressed later.
        raise ValueError(f"Found files at {json_save_dir}; please move or remove them.")

    Path(args.save_dir).mkdir(exist_ok=True)
    results, num_replicas = eval_data_dir(
        args.data_dir,
        json_save_dir,
        args.model_name,
        type_path=args.type_path,
        bs=args.bs,
        fp16=args.fp16,
        task=args.task,
        local_rank=args.local_rank,
        n_obs=args.n_obs,
        max_source_length=args.max_source_length,
        prefix=args.prefix,
        src_lang=args.src_lang,
        tgt_lang=args.tgt_lang,
        **generate_kwargs,
    )

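    # Only rank 0 (or local_rank -1 when run without distributed.launch) gathers
    # the per-rank files, computes metrics, and writes the final outputs.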
    if args.local_rank <= 0:
        save_dir = Path(args.save_dir)
        save_dir.mkdir(exist_ok=True)
        partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
        preds = combine_partial_results(partial_results)
        tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target")
        with open(tgt_file) as f:
            labels = [x.rstrip() for x in f][: len(preds)]

        # Calculate metrics, save metrics, and save _generations.txt
        calc_bleu = "translation" in args.task
        score_fn = calculate_bleu if calc_bleu else calculate_rouge
        metric_name = "bleu" if calc_bleu else "rouge"
        metrics: Dict = score_fn(preds, labels)
        metrics["n_obs"] = len(preds)
        runtime = time.time() - start_time
        metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 4)
        metrics["n_gpus"] = num_replicas
        # TODO(@stas00): add whatever metadata to metrics
        metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
        save_json(metrics, metrics_save_path, indent=None)
        print(metrics)
        write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
        if args.debug:
            write_txt_file(labels, save_dir.joinpath(f"{args.type_path}.target"))
        else:
            shutil.rmtree(json_save_dir)


def combine_partial_results(partial_results) -> List[str]:
    """Concatenate partial results into one file, then sort it by id."""
    records = []
    for partial_result in partial_results:
        records.extend(partial_result)
    records = sorted(records, key=lambda x: x["id"])
    preds = [x["pred"] for x in records]
    return preds


def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[List[Dict]]:
    # Poll until every rank has saved its rank_*.json file; a JSONDecodeError
    # means a file is still being written, so keep retrying until the timeout.
    start_wait = time.time()
    logger.info("waiting for all nodes to finish")
    while (time.time() - start_wait) < timeout:
        json_files = list(save_dir.glob("rank_*.json"))
        if len(json_files) < num_replicas:
            continue
        try:
            # make sure all json files are fully saved
            json_data = lmap(load_json, json_files)
            return json_data
        except JSONDecodeError:
            continue
    else:
        raise TimeoutError("Rank 0 gave up on waiting for other processes")
    # Unreachable


if __name__ == "__main__":
    # Usage for MT:
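    # Illustrative command (model name and data paths are placeholders):
    # python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
    #     --model_name Helsinki-NLP/opus-mt-en-ro --data_dir wmt_en_ro --save_dir mt_gen \
    #     --task translation --src_lang en --tgt_lang ro --fp16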
    run_generate()