# ColossalEval inference entry point.
import argparse
import copy
import os
from typing import Dict, List

import torch
import torch.distributed as dist
from colossal_eval import dataset, models, utils

import colossalai
from colossalai.logging import get_dist_logger

logger = get_dist_logger()


def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
    """
    Remove inference result per rank and merge them into one file.

    Args:
        world_size: Number of processes for inference.
        save_path: The folder for storing inference results.
        model_names: Names of models for inference.
        dataset_names: Names of dataset for inference.

    Raises:
        FileNotFoundError: If any rank's result file is missing, which
            indicates an error occurred during inference.
    """

    for model_name in model_names:
        for dataset_name, categories in dataset_names.items():
            all_answers = {}
            for category in categories:
                answers = {"data": []}
                # Per-rank result files written by main() during inference.
                rank_files = [
                    os.path.join(save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json")
                    for r in range(world_size)
                ]

                # First pass: read every rank's file. A missing file aborts the
                # merge before anything has been deleted.
                for rank_file in rank_files:
                    if not os.path.exists(rank_file):
                        raise FileNotFoundError(
                            f"Directory {rank_file} not found. There may be an error during inference time."
                        )
                    rank_answers = utils.jload(rank_file)
                    answers["data"].extend(rank_answers["data"])
                    # inference_kwargs are identical across ranks; keep the last seen.
                    answers["inference_kwargs"] = rank_answers["inference_kwargs"]

                # Second pass: best-effort cleanup of the per-rank files.
                for rank_file in rank_files:
                    try:
                        os.remove(rank_file)
                    except Exception as e:
                        print(e)

                all_answers[category] = answers

            logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.")
            utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"))

        logger.info(f"Save inference results of model {model_name} for all dataset.")
    logger.info("Save inference results of all models for all dataset.")


def main(args):
    """Run distributed inference for all configured models over all configured datasets.

    Workflow:
      1. Launch the ColossalAI distributed context.
      2. Load (or convert and save) each dataset listed in the config.
      3. For every model, split each category's data across ranks, run
         (possibly multi-turn) inference, and dump per-rank result files.
      4. On rank 0, merge per-rank files into one result file per dataset.

    Args:
        args: Parsed CLI namespace with ``config``, ``load_dataset`` and
            ``inference_save_path`` attributes.
    """
    colossalai.launch_from_torch(config={}, seed=42)
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    inference_data = {}
    debug_args = {}
    few_shot_args = {}
    multiturn_args = {}

    config = utils.jload(args.config)

    model_parameters = config["model"]
    dataset_parameters = config["dataset"]

    for dataset_parameter in dataset_parameters:
        path = dataset_parameter["path"]
        save_path = dataset_parameter["save_path"]
        dataset_name = dataset_parameter["name"]
        debug_args[dataset_name] = dataset_parameter["debug"]
        few_shot_args[dataset_name] = dataset_parameter["few_shot"]

        if not args.load_dataset:
            # Reuse a previously converted dataset instead of rebuilding it.
            if os.path.exists(save_path):
                dataset_ = utils.jload(save_path)
                inference_data[dataset_name] = dataset_["test"]
            else:
                raise Exception(
                    "Can't find the converted dataset. You may set load_dataset True to store the dataset first."
                )

            continue

        # Resolve the dataset class by name from the config; getattr avoids
        # eval() on config-supplied strings while behaving the same for valid names.
        dataset_class = getattr(dataset, dataset_parameter["dataset_class"])
        if not issubclass(dataset_class, dataset.BaseDataset):
            raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.")

        dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])

        dataset_.save(save_path)

        if hasattr(dataset_, "multiturn") and dataset_.multiturn:
            multiturn_args[dataset_name] = True
            logger.info(f"{dataset_parameter['dataset_class']} is a multiturn dataset.")
        else:
            multiturn_args[dataset_name] = False

        inference_data[dataset_name] = dataset_.dataset["test"]

    for model_parameter in model_parameters:
        model_name = model_parameter["name"]
        # Resolve and validate the model class BEFORE instantiation so a bad
        # config fails fast without loading any model weights.
        model_class = getattr(models, model_parameter["model_class"])
        if not issubclass(model_class, models.BaseModel):
            raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")

        parameters = model_parameter["parameters"]
        parameters.update({"logger": logger})
        parameters.update({"prompt_template": utils.prompt_templates[parameters["prompt_template"]]})

        model_ = model_class(**parameters)

        for dataset_name, split_data in inference_data.items():
            start = 0
            prev_questions = None
            for category, category_data in split_data.items():
                num_turn = category_data["inference_kwargs"].get("turns", 1)

                if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
                    raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")

                answers_to_dump = copy.deepcopy(category_data)
                partition_size = len(category_data["data"]) // world_size
                redundant = len(category_data["data"]) % world_size

                # Ensure that the amount of data for inference is as consistent as possible across different processes.
                lengths = [partition_size for _ in range(world_size)]
                for j in range(redundant):
                    lengths[(j + start) % world_size] += 1

                # Rotate where leftover samples land so no single rank always
                # carries the extra work across categories.
                start = (start + redundant) % world_size

                for turn in range(num_turn):
                    if turn == 0:
                        # First turn: take this rank's contiguous slice of the questions.
                        questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
                    else:
                        # Later turns: feed the previous turn's answers back in.
                        questions = prev_questions

                    answers_per_rank = model_.inference(
                        questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
                    )
                    prev_questions = answers_per_rank

                answers_to_dump["data"] = answers_per_rank

                utils.jdump(
                    answers_to_dump,
                    os.path.join(
                        args.inference_save_path,
                        model_name,
                        f"{dataset_name}_{category}_inference_results_rank{rank}.json",
                    ),
                )

        logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")

        # Free GPU memory before loading the next model.
        del model_
        torch.cuda.empty_cache()

    dist.barrier()
    if rank == 0:
        # Only rank 0 merges the per-rank files to avoid concurrent writes.
        model_names = [model_parameter["name"] for model_parameter in model_parameters]
        dataset_names = {key: list(inference_data[key].keys()) for key in inference_data}
        rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names)


if __name__ == "__main__":
    # CLI entry point: parse arguments and hand off to main().
    arg_parser = argparse.ArgumentParser(description="ColossalEval inference process.")
    arg_parser.add_argument("--config", type=str, default=None, required=True, help="path to config file")
    arg_parser.add_argument("--load_dataset", default=False, action="store_true")
    arg_parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results")
    main(arg_parser.parse_args())