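"""
Distributed (multi-GPU) evaluation for seq2seq models. Each process generates
predictions for its shard of the dataset and writes them to
{save_dir}_tmp/rank_{local_rank}_output.json; rank 0 then gathers every shard,
computes ROUGE (or BLEU for translation tasks), and writes the metrics plus a
{type_path}_generations.txt file to save_dir.
"""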
import argparse
import shutil
import time
from json import JSONDecodeError
from logging import getLogger
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


logger = getLogger(__name__)

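# Support running both as part of the package (relative import) and as a standalone script.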
try:
    from .utils import (
        Seq2SeqDataset,
        calculate_bleu,
        calculate_rouge,
        lmap,
        load_json,
        parse_numeric_n_bool_cl_kwargs,
        save_json,
        use_task_specific_params,
        write_txt_file,
    )
except ImportError:
    from utils import (
        Seq2SeqDataset,
        calculate_bleu,
        calculate_rouge,
        lmap,
        load_json,
        parse_numeric_n_bool_cl_kwargs,
        save_json,
        use_task_specific_params,
        write_txt_file,
    )


def eval_data_dir(
    data_dir,
    save_dir: str,
    model_name: str,
    bs: int = 8,
    max_source_length: int = 1024,
    type_path="val",
    n_obs=None,
    fp16=False,
    task="summarization",
    local_rank=None,
    **generate_kwargs,
) -> Tuple[List[Dict], int]:
    """Run evaluation on one GPU's shard of the data and save predictions to {save_dir}/rank_{local_rank}_output.json."""
    model_name = str(model_name)
    assert local_rank is not None
    torch.distributed.init_process_group(backend="nccl", rank=local_rank)
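    # With no explicit init_method, init_process_group uses the env:// rendezvous, so
    # MASTER_ADDR/MASTER_PORT/WORLD_SIZE are expected to be set by the launcher
    # (e.g. torch.distributed.launch, which also passes --local_rank to this script).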

    save_dir = Path(save_dir)
    save_path = save_dir.joinpath(f"rank_{local_rank}_output.json")
    torch.cuda.set_device(local_rank)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
    if fp16:
        model = model.half()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.
    use_task_specific_params(model, task)  # update config with task specific params
    if max_source_length is None:
        max_source_length = tokenizer.model_max_length
    ds = Seq2SeqDataset(
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length=1024,
        type_path=type_path,
        n_obs=n_obs,
        prefix=model.config.prefix,
    )
    # I set shuffle=True for a more accurate progress bar.
    # If all the longest samples are first, the prog bar estimate is too high at the beginning.
    sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True)
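    # distributed=True shards the dataset across processes, so each rank only
    # generates predictions for its own slice of the examples.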
    data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
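    # (ds.collate_fn batches each list of examples into padded input_ids/attention_mask tensors)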
    results = []
    for batch in tqdm(data_loader):
        summaries = model.generate(
            input_ids=batch["input_ids"].to(model.device),
            attention_mask=batch["attention_mask"].to(model.device),
            **generate_kwargs,
        )
        preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        ids = batch["ids"]
        for i, pred in enumerate(preds):
            results.append(dict(pred=pred, id=ids[i].item()))
    save_json(results, save_path)
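    # num_replicas is returned so that rank 0 knows how many rank_*.json shards to wait for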
    return results, sampler.num_replicas


def run_generate():
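    """Parse CLI args, run eval_data_dir on this rank, and (on rank 0) gather and score all shards."""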
    parser = argparse.ArgumentParser(
        epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
    )
    parser.add_argument("--data_dir", type=str, help="like cnn_dm/test.source")
    parser.add_argument(
        "--model_name",
        type=str,
        help="like facebook/bart-large-cnn,t5-base, etc.",
        default="sshleifer/distilbart-xsum-12-3",
    )
    parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
    parser.add_argument("--max_source_length", type=int, default=None)
    parser.add_argument(
        "--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test"
    )
    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
    parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
    parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
    parser.add_argument(
        "--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch"
    )

    parser.add_argument(
        "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
    )
    parser.add_argument(
        "--sync_timeout",
        type=int,
        default=600,
        required=False,
        help="How long should master process wait for other processes to finish.",
    )
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--debug", action="store_true")
    start_time = time.time()
    args, rest = parser.parse_known_args()
    generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest)
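    # e.g. `--num_beams=2 --decoder_start_token_id=4` on the command line becomes
    # {"num_beams": 2, "decoder_start_token_id": 4}, forwarded verbatim to model.generate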
    if generate_kwargs and args.local_rank <= 0:
        print(f"parsed the following generate kwargs: {generate_kwargs}")
    json_save_dir = Path(args.save_dir + "_tmp")
    json_save_dir.mkdir(exist_ok=True)  # mkdir(exist_ok=True) is atomic, so concurrent ranks can race here safely
    intermediate_files = list(json_save_dir.glob("rank_*.json"))
    if intermediate_files:
        raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
        # In theory, a node could finish and save before another node hits this. If this happens, we can address later.

    Path(args.save_dir).mkdir(exist_ok=True)
    results, num_replicas = eval_data_dir(
        args.data_dir,
        json_save_dir,
        args.model_name,
        type_path=args.type_path,
        bs=args.bs,
        fp16=args.fp16,
        task=args.task,
        local_rank=args.local_rank,
        n_obs=args.n_obs,
        max_source_length=args.max_source_length,
        **generate_kwargs,
    )

    if args.local_rank <= 0:
        save_dir = Path(args.save_dir)
        save_dir.mkdir(exist_ok=True)
        partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
        preds = combine_partial_results(partial_results)
        tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target")
        with open(tgt_file) as f:
            labels = [x.rstrip() for x in f][: len(preds)]

        # Calculate metrics, save metrics, and save {type_path}_generations.txt
        calc_bleu = "translation" in args.task
        score_fn = calculate_bleu if calc_bleu else calculate_rouge
        metric_name = "bleu" if calc_bleu else "rouge"
        metrics: Dict = score_fn(preds, labels)
        metrics["n_obs"] = len(preds)
        runtime = time.time() - start_time
        metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2)
        # TODO(@stas00): add whatever metadata to metrics
        metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
        save_json(metrics, metrics_save_path, indent=None)
        print(metrics)
        write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
        if args.debug:
            write_txt_file(labels, save_dir.joinpath(f"{args.type_path}.target"))
        else:
            shutil.rmtree(json_save_dir)


def combine_partial_results(partial_results: List[List[Dict]]) -> List[str]:
    """Concatenate partial results into one file, then sort it by id."""
    records = []
    for partial_result in partial_results:
        records.extend(partial_result)
    records = list(sorted(records, key=lambda x: x["id"]))
    preds = [x["pred"] for x in records]
    return preds


def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[List[Dict]]:
    # Poll save_dir until every replica's rank_*.json file exists and parses cleanly.
    start_wait = time.time()
    logger.info("waiting for all nodes to finish")
    json_data = None
    while (time.time() - start_wait) < timeout:
        json_files = list(save_dir.glob("rank_*.json"))
        if len(json_files) < num_replicas:
            continue
        try:
            # make sure all json files are fully saved
            json_data = lmap(load_json, json_files)
            return json_data
        except JSONDecodeError:
            continue
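    # while/else: the else branch runs only when the loop condition becomes false,
    # i.e. the timeout expired before all shards appeared.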
    else:
        raise TimeoutError("Rank 0 gave up on waiting for other processes")
    # Unreachable


if __name__ == "__main__":
    # Usage for MT:
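    # Illustrative invocation (model name and data paths are placeholders):
    # python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
    #     --model_name Helsinki-NLP/opus-mt-en-ro --data_dir wmt_en_ro \
    #     --save_dir translations --task translation --num_beams 5 --fp16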
    run_generate()