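"""ColossalEval inference script (inference.py).

Loads the models and datasets listed in a config file (``--config``). Each dataset category is
sharded across data-parallel groups, with optional tensor parallelism via ``--tp_size``. Every
model runs inference on its shard and writes per-rank results under ``--inference_save_path``,
which rank 0 finally merges.
"""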
import argparse
import copy
import os
from typing import Dict, List

import torch
import torch.distributed as dist
from colossal_eval import dataset, models, utils

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import get_dist_logger
from colossalai.shardformer import ShardConfig

# Logger from colossalai.logging.get_dist_logger, shared by every function in this script.
logger = get_dist_logger()


def rm_and_merge(dp_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
    """
    Remove inference result per rank and merge them into one file.

    For each model and dataset, the per-rank files
    ``{save_path}/{model_name}/{dataset_name}_{category}_inference_results_dp_rank{r}.json``
    (``r`` in ``range(dp_size)``) are concatenated in rank order, category by category, into
    ``{save_path}/{model_name}/{dataset_name}_inference_results.json``.

    The per-rank files are deleted only after the merged file has been written. A failed write
    or a missing rank file therefore never destroys results that were already produced.

    Args:
        dp_size: Number of groups for data parallel.
        save_path: The folder for storing inference results.
        model_names: Names of models for inference.
        dataset_names: Maps each dataset name to the list of its category names.

    Raises:
        FileNotFoundError: If a per-rank result file is missing, which usually means inference
            failed on that rank.
    """
    for model_name in model_names:
        model_dir = os.path.join(save_path, model_name)
        for dataset_name, categories in dataset_names.items():
            all_answers = {}
            # Per-rank files that were read successfully; removed once the merged file is written.
            merged_rank_files = []

            for category in categories:
                answers = {"data": []}
                for r in range(dp_size):
                    rank_file = os.path.join(
                        model_dir, f"{dataset_name}_{category}_inference_results_dp_rank{r}.json"
                    )
                    if not os.path.exists(rank_file):
                        raise FileNotFoundError(
                            f"Inference result file {rank_file} not found. There may be an error during inference time."
                        )
                    rank_answers = utils.jload(rank_file)
                    answers["data"].extend(rank_answers["data"])
                    # Each rank dumps a copy of the same category_data, so the kwargs presumably
                    # match across ranks; the last one read is kept (as before).
                    answers["inference_kwargs"] = rank_answers["inference_kwargs"]
                    merged_rank_files.append(rank_file)

                all_answers[category] = answers

            logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.")
            utils.jdump(all_answers, os.path.join(model_dir, f"{dataset_name}_inference_results.json"))

            # Cleanup is best-effort: a file that cannot be removed is reported, not fatal.
            for rank_file in merged_rank_files:
                try:
                    os.remove(rank_file)
                except OSError as e:
                    logger.warning(f"Failed to remove {rank_file}: {e}")

        logger.info(f"Save inference results of model {model_name} for all dataset.")
    logger.info("Save inference results of all models for all dataset.")


def _partition_lengths(total: int, dp_size: int, start: int) -> List[int]:
    """Return how many of ``total`` samples each of the ``dp_size`` data-parallel ranks gets.

    Every rank receives ``total // dp_size`` samples. The ``total % dp_size`` leftover samples are
    handed out one apiece to ranks ``start, start + 1, ...`` (wrapping around). Callers advance
    ``start`` between categories so the extra work is spread across ranks.
    """
    base, redundant = divmod(total, dp_size)
    lengths = [base] * dp_size
    for j in range(redundant):
        lengths[(j + start) % dp_size] += 1
    return lengths


def _load_inference_data(dataset_parameters: List[Dict], load_dataset: bool):
    """Build the inference splits described by the ``dataset`` section of the config.

    If ``load_dataset`` is true, each dataset is converted with its configured ``dataset_class``
    and saved to ``save_path``. Otherwise, the previously saved conversion is read back from
    ``save_path`` (only its ``test`` split is used).

    Returns:
        ``(inference_data, debug_args, few_shot_args)``. Each is a dict keyed by dataset name,
        plus ``{name}_train`` / ``{name}_reference`` for extra splits that were requested and exist.

    Raises:
        Exception: If ``load_dataset`` is false and ``save_path`` does not exist.
        ValueError: If a configured dataset class is not a subclass of ``BaseDataset``.
    """
    inference_data = {}
    debug_args = {}
    few_shot_args = {}

    for dataset_parameter in dataset_parameters:
        path = dataset_parameter["path"]
        save_path = dataset_parameter["save_path"]
        dataset_name = dataset_parameter["name"]
        debug_args[dataset_name] = dataset_parameter["debug"]
        few_shot_args[dataset_name] = dataset_parameter["few_shot"]
        forward_only = dataset_parameter.get("forward_only", False)
        load_train = dataset_parameter.get("load_train", False)
        load_reference = dataset_parameter.get("load_reference", False)

        if not load_dataset:
            if not os.path.exists(save_path):
                raise Exception(
                    "Can't find the converted dataset. You may set load_dataset True to store the dataset first."
                )
            inference_data[dataset_name] = utils.jload(save_path)["test"]
            continue

        # getattr instead of eval: the class name comes from the config file and must not be
        # evaluated as arbitrary code.
        dataset_class = getattr(dataset, dataset_parameter["dataset_class"])
        if not issubclass(dataset_class, dataset.BaseDataset):
            raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.")

        dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"], forward_only, load_train, load_reference)
        dataset_.save(save_path)

        if getattr(dataset_, "multiturn", False):
            logger.info(f"{dataset_parameter['dataset_class']} is a multiturn dataset.")

        inference_data[dataset_name] = dataset_.dataset["test"]

        # Optional extra splits are exposed as separate datasets sharing the parent's settings.
        for split, requested in (("train", load_train), ("reference", load_reference)):
            if requested and split in dataset_.dataset:
                split_name = f"{dataset_name}_{split}"
                debug_args[split_name] = dataset_parameter["debug"]
                few_shot_args[split_name] = dataset_parameter["few_shot"]
                inference_data[split_name] = dataset_.dataset[split]

    return inference_data, debug_args, few_shot_args


def main(args: argparse.Namespace):
    """Run distributed inference for every model/dataset pair in ``args.config``.

    The world is arranged as a (DP, TP) process mesh. Each category of each dataset is split
    across DP groups. Within a DP group, the ``tp_rank == 0`` process writes
    ``{dataset}_{category}_inference_results_dp_rank{dp_rank}.json``. After a barrier, global
    rank 0 merges those files with ``rm_and_merge``.

    Raises:
        ValueError: On a missing ``--inference_save_path``, a non-positive ``--tp_size``, an
            invalid ``turns`` value, or a model class that is not a ``BaseModel``.
        Exception: If the world size is not a multiple of the TP size, or few-shot data is missing.
    """
    # Fail fast instead of hitting a TypeError in os.path.join after all inference has run.
    if args.inference_save_path is None:
        raise ValueError("--inference_save_path must be set.")

    colossalai.launch_from_torch(config={}, seed=42)
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    DP_AXIS = 0
    TP_AXIS = 1

    # Validate before dividing: a zero/negative TP size would otherwise crash or yield a bogus DP size.
    if args.tp_size < 1:
        raise ValueError(f"TP size must be a positive integer, got {args.tp_size}.")
    if world_size % args.tp_size != 0:
        raise Exception(
            f"TP size is {args.tp_size} while world size is {world_size}! Please make sure world size is a multiple of TP size!"
        )
    dp_size = world_size // args.tp_size

    if rank == 0:
        logger.info("Setting TP and DP...")
        logger.info(f"TP size: {args.tp_size}, DP size: {dp_size}")

    pg_mesh = ProcessGroupMesh(dp_size, args.tp_size)
    tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

    coordinates = pg_mesh._coord
    dp_rank = coordinates[DP_AXIS]
    tp_rank = coordinates[TP_AXIS]

    shard_config = (
        ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=True)
        if args.tp_size > 1
        else None
    )

    config = utils.jload(args.config)
    model_parameters = config["model"]
    dataset_parameters = config["dataset"]

    inference_data, debug_args, few_shot_args = _load_inference_data(dataset_parameters, args.load_dataset)

    if rank == 0:
        logger.info(f"Dataset for inference are: {list(inference_data.keys())}")

    for model_parameter in model_parameters:
        model_name = model_parameter["name"]
        # getattr instead of eval: the class name comes from the config file.
        model_class = getattr(models, model_parameter["model_class"])
        # Check before instantiating, since constructing the model may load large weights.
        if not issubclass(model_class, models.BaseModel):
            raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")

        # Copy so the loaded config is not mutated (prompt_template is replaced by the template object).
        parameters = dict(model_parameter["parameters"])
        parameters["logger"] = logger
        parameters["prompt_template"] = utils.prompt_templates[parameters["prompt_template"]]
        parameters["shard_config"] = shard_config

        model_ = model_class(**parameters)
        model_save_dir = os.path.join(args.inference_save_path, model_name)

        for dataset_name, split_data in inference_data.items():
            # First rank to receive a leftover sample; rotated per category to balance load.
            start = 0
            for category, category_data in split_data.items():
                inference_kwargs = category_data["inference_kwargs"]
                num_turn = inference_kwargs.get("turns", 1)
                if num_turn < 1:
                    raise ValueError(
                        f"Dataset {dataset_name} category {category} must have at least one turn, got {num_turn}."
                    )

                if few_shot_args[dataset_name] and inference_kwargs.get("few_shot_data", None) is None:
                    raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")

                answers_to_dump = copy.deepcopy(category_data)

                # Ensure that the amount of data for inference is as consistent as possible across different processes.
                num_samples = len(category_data["data"])
                lengths = _partition_lengths(num_samples, dp_size, start)
                start = (start + num_samples % dp_size) % dp_size

                offset = sum(lengths[:dp_rank])
                answers_per_rank = category_data["data"][offset : offset + lengths[dp_rank]]
                # Multi-turn datasets feed each turn's answers back in as the next turn's questions.
                for _ in range(num_turn):
                    answers_per_rank = model_.inference(
                        answers_per_rank, inference_kwargs=inference_kwargs, debug=debug_args[dataset_name]
                    )

                answers_to_dump["data"] = answers_per_rank

                # Ranks in a TP group hold the same answers, so only one of them writes.
                if tp_rank == 0:
                    os.makedirs(model_save_dir, exist_ok=True)
                    utils.jdump(
                        answers_to_dump,
                        os.path.join(
                            model_save_dir,
                            f"{dataset_name}_{category}_inference_results_dp_rank{dp_rank}.json",
                        ),
                    )

        logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")

        del model_
        torch.cuda.empty_cache()

    # Every rank must have written its shard before rank 0 merges them.
    dist.barrier()
    if rank == 0:
        model_names = [model_parameter["name"] for model_parameter in model_parameters]
        dataset_names = {name: list(split) for name, split in inference_data.items()}
        rm_and_merge(dp_size, args.inference_save_path, model_names, dataset_names)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ColossalEval inference process.")
    parser.add_argument("--config", type=str, default=None, required=True, help="path to config file")
    parser.add_argument(
        "--load_dataset",
        default=False,
        action="store_true",
        help="convert the raw datasets and save them to each dataset's save_path before inference",
    )
    # Required: main() joins every output path onto this directory.
    parser.add_argument(
        "--inference_save_path", type=str, default=None, required=True, help="path to save inference results"
    )
    parser.add_argument("--tp_size", type=int, default=1, help="tensor parallel size, used for large model inference")
    args = parser.parse_args()

    main(args)