run_distributed_eval.py
import argparse
from logging import getLogger
from pathlib import Path
from typing import Dict, List

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


logger = getLogger(__name__)

try:
    from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
except ImportError:
    from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params


DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def eval_data_dir(
    data_dir,
    save_dir: str,
    model_name: str,
    bs: int = 8,
    max_source_length: int = 1024,
    type_path="val",
    n_obs=None,
    fp16=False,
    save_source=False,
    num_beams: int = 4,
    task="summarization",
    local_rank=None,
    **generate_kwargs,
) -> List[Dict]:
    """Run evaluation on this rank's shard of the data and save it to {save_dir}/rank_{local_rank}_output.json."""
    model_name = str(model_name)
    assert local_rank is not None
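    # Join the default process group. With the default "env://" init method, the
    # rendezvous info (MASTER_ADDR, MASTER_PORT, WORLD_SIZE) is read from env vars
    # set by the launcher (e.g. torch.distributed.launch).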
    torch.distributed.init_process_group(backend="nccl", rank=local_rank)

    save_dir = Path(save_dir)
    save_path = save_dir.joinpath(f"rank_{local_rank}_output.json")
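    # Pin this process to its assigned GPU so the .cuda() call below targets the right device.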
    torch.cuda.set_device(local_rank)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
    if fp16:
        model = model.half()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.
    use_task_specific_params(model, task)  # update config with task specific params
    if max_source_length is None:
        max_source_length = tokenizer.model_max_length
    ds = Seq2SeqDataset(
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length=1024,
        type_path=type_path,
        n_obs=n_obs,
        prefix=model.config.prefix,
    )
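    # distributed=True shards the sortish-ordered indices so each rank generates on a disjoint slice.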
    sampler = ds.make_sortish_sampler(bs, distributed=True)
    data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
    dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False)  # kwargs for tokenizer.batch_decode
    results = []
    for batch in tqdm(data_loader):
        summaries = model.generate(
            input_ids=batch["input_ids"].to(model.device),
            attention_mask=batch["attention_mask"].to(model.device),
            num_beams=num_beams,
            **generate_kwargs,
        )
        preds = tokenizer.batch_decode(summaries, **dec_kwargs)
        labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
        if save_source:
            docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs)
        for i in range(len(labels)):
            label, pred = labels[i], preds[i]
            if save_source:
                results.append(dict(pred=pred, label=label, source=docs[i]))
            else:
                results.append(dict(pred=pred, label=label))
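    # Each rank writes only its own shard; merge the rank_*_output.json files downstream.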
    save_json(results, save_path)
    return results


def run_generate():
    parser = argparse.ArgumentParser(
        epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
    )
    parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source")
    parser.add_argument(
        "--model_name",
        type=str,
        help="like facebook/bart-large-cnn,t5-base, etc.",
        default="sshleifer/distilbart-xsum-12-3",
    )
    parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
    parser.add_argument("--max_source_length", type=int, default=None)
    parser.add_argument(
        "--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test"
    )
    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
    parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
    parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
    parser.add_argument(
        "--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch"
    )

    parser.add_argument(
        "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
    )
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--save_source", action="store_true")

    args, rest = parser.parse_known_args()
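    # Flags argparse did not recognize (e.g. --num_beams=2) end up in `rest` and are
    # converted to numeric kwargs that eval_data_dir forwards to model.generate.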
    generate_kwargs = parse_numeric_cl_kwargs(rest)
    if generate_kwargs:
        print(f"parsed the following generate kwargs: {generate_kwargs}")
    Path(args.save_dir).mkdir(exist_ok=True)
    eval_data_dir(
        args.input_path,
        args.save_dir,
        args.model_name,
        type_path=args.type_path,
        bs=args.bs,
        fp16=args.fp16,
        task=args.task,
        local_rank=args.local_rank,
        n_obs=args.n_obs,
        save_source=args.save_source,
        max_source_length=args.max_source_length,
        **generate_kwargs,
    )


if __name__ == "__main__":
    # Usage for MT:
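    # (assumed example launch; adjust the model name and data paths to your setup)
    #   python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
    #       --model_name Helsinki-NLP/opus-mt-en-ro --input_path $WMT_EN_RO_DIR \
    #       --type_path test --task translation --fp16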
    run_generate()