# inference.py
import argparse
import copy
import os
from typing import Dict, List

import torch
import torch.distributed as dist
from colossal_eval import dataset, models, utils

import colossalai
from colossalai.logging import get_dist_logger

logger = get_dist_logger()


def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
    """
    Merge per-rank inference results into a single file and remove the per-rank files.

    Args:
        world_size: Number of processes used for inference.
        save_path: The folder storing inference results.
        model_names: Names of the models used for inference.
        dataset_names: Mapping from dataset name to its list of categories.

    """

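    # Per-rank result files follow the naming scheme
    # {save_path}/{model_name}/{dataset_name}_{category}_inference_results_rank{r}.json;
    # ranks are concatenated in order and the per-rank files are then deleted.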
    for model_name in model_names:
        for dataset_name, categories in dataset_names.items():
            all_answers = {}
            for category in categories:
                all_answers[category] = {"data": []}
                answers = {"data": []}

                for r in range(world_size):
                    file_path = os.path.join(
                        save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
                    )
                    if not os.path.exists(file_path):
                        raise Exception(
                            f"File {file_path} not found. There may have been an error during inference."
                        )
                    else:
                        rank_answers = utils.jload(file_path)
                        answers["data"].extend(rank_answers["data"])
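                        # inference_kwargs are identical across ranks, so keeping any single copy suffices.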
                        answers["inference_kwargs"] = rank_answers["inference_kwargs"]

                for r in range(world_size):
                    try:
                        file_path = os.path.join(
                            save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
                        )
                        os.remove(file_path)
                    except OSError as e:
                        logger.warning(e)

                all_answers[category] = answers

            logger.info(f"Saving inference results of model {model_name} on dataset {dataset_name}.")
            utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"))

        logger.info(f"Saved inference results of model {model_name} for all datasets.")
    logger.info("Saved inference results of all models for all datasets.")


def main(args):
    colossalai.launch_from_torch(config={}, seed=42)
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    inference_data = {}
    debug_args = {}
    few_shot_args = {}

    config = utils.jload(args.config)

    model_parameters = config["model"]
    dataset_parameters = config["dataset"]
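    # The config supplies a list of model specs ({name, model_class, parameters})
    # and a list of dataset specs ({name, dataset_class, path, save_path, debug, few_shot}),
    # matching the keys consumed below.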

    for dataset_parameter in dataset_parameters:
        path = dataset_parameter["path"]
        save_path = dataset_parameter["save_path"]
        dataset_name = dataset_parameter["name"]
        debug_args[dataset_name] = dataset_parameter["debug"]
        few_shot_args[dataset_name] = dataset_parameter["few_shot"]

        if not args.load_dataset:
            if os.path.exists(save_path):
                dataset_ = utils.jload(save_path)
                inference_data[dataset_name] = dataset_["test"]
            else:
                raise Exception(
                    "Can't find the converted dataset. Pass --load_dataset to convert and save the dataset first."
                )

            continue

        dataset_class = getattr(dataset, dataset_parameter["dataset_class"])
        if not issubclass(dataset_class, dataset.BaseDataset):
            raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.")

        dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])

        dataset_.save(save_path)
        inference_data[dataset_name] = dataset_.dataset["test"]

    for model_parameter in model_parameters:
        model_name = model_parameter["name"]
        model_class = getattr(models, model_parameter["model_class"])
        if not issubclass(model_class, models.BaseModel):
            raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")

        parameters = model_parameter["parameters"]
        parameters.update({"logger": logger})
        parameters.update({"prompt_template": utils.prompt_templates[parameters["prompt_template"]]})

        model_ = model_class(**parameters)

        for dataset_name, split_data in inference_data.items():
            start = 0
            for category, category_data in split_data.items():
                if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
                    raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")

                answers_to_dump = copy.deepcopy(category_data)
                partition_size = len(category_data["data"]) // world_size
                redundant = len(category_data["data"]) % world_size

                # Distribute the remainder so that per-rank shard sizes differ by at most one.
                lengths = [partition_size for _ in range(world_size)]
                for j in range(redundant):
                    lengths[(j + start) % world_size] += 1
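                # e.g., 10 samples on 4 ranks with start == 0 gives lengths == [3, 3, 2, 2];
                # advancing `start` rotates which ranks receive the extra samples across categories.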

                start = (start + redundant) % world_size

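                # Each rank takes its contiguous shard:
                # indices [sum(lengths[:rank]), sum(lengths[:rank]) + lengths[rank]).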
                questions = category_data["data"][sum(lengths[:rank]) : sum(lengths[:rank]) + lengths[rank]]

                answers_per_rank = model_.inference(
                    questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
                )

                answers_to_dump["data"] = answers_per_rank

                utils.jdump(
                    answers_to_dump,
                    os.path.join(
                        args.inference_save_path,
                        model_name,
                        f"{dataset_name}_{category}_inference_results_rank{rank}.json",
                    ),
                )

        logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")

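        # Release the current model and free cached GPU memory before loading the next one.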
        del model_
        torch.cuda.empty_cache()

    dist.barrier()
    if rank == 0:
        model_names = [model_parameter["name"] for model_parameter in model_parameters]
        dataset_names = {key: list(inference_data[key].keys()) for key in inference_data}
        rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ColossalEval inference process.")
    parser.add_argument("--config", type=str, default=None, required=True, help="path to config file")
    parser.add_argument(
        "--load_dataset", default=False, action="store_true", help="load the raw dataset, convert and save it first"
    )
    parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results")
    args = parser.parse_args()

    main(args)
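
# A minimal usage sketch (hypothetical file names and class names; the exact
# config schema is defined by the colossal_eval dataset/model classes in use):
#
#   torchrun --nproc_per_node 4 inference.py \
#       --config config.json --load_dataset --inference_save_path results/
#
# where config.json has the shape consumed above, e.g.:
#
#   {
#       "model": [
#           {"name": "my-model", "model_class": "HuggingFaceCausalLM",
#            "parameters": {"prompt_template": "plain", ...}}
#       ],
#       "dataset": [
#           {"name": "my-dataset", "dataset_class": "MMLUDataset", "path": "data/mmlu",
#            "save_path": "converted/mmlu.json", "debug": false, "few_shot": true}
#       ]
#   }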