import itertools
import json
import logging
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
import torch

import lm_eval.api.metrics
import lm_eval.api.model
import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models
from lm_eval.api.eval_config import EvaluationConfig
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
    consolidate_group_results,
    consolidate_results,
    get_sample_size,
    get_subtask_list,
    get_task_list,
    prepare_print_tasks,
    print_writeout,
    run_task_tests,
)
from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
    handle_non_serializable,
    hash_string,
    positional_deprecated,
    setup_logging,
    simple_parse_args_string,
)


if TYPE_CHECKING:
    from lm_eval.api.model import LM
    from lm_eval.api.task import Task


eval_logger = logging.getLogger(__name__)

@positional_deprecated
def simple_evaluate(
    config: "EvaluationConfig",
    # TODO: bootstrap_iters is not passed from cli_evaluate
    bootstrap_iters: int = 100000,
    evaluation_tracker: Optional[EvaluationTracker] = None,
    task_manager: Optional[TaskManager] = None,
):
    """Instantiate and evaluate a model on a list of tasks.

    Most of the settings below are read from `config` (an `EvaluationConfig`);
    each entry documents the corresponding `config` attribute. `bootstrap_iters`,
    `evaluation_tracker`, and `task_manager` are passed as keyword arguments.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[Union[str, dict]]
        String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_obj.
        Ignored if `model` argument is an LM object.
    :param tasks: list[Union[str, dict, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests. `None` if not caching.
    :param rewrite_requests_cache: bool, optional
        Rewrites all the request cache if set to `True`. `None` if not desired.
    :param delete_requests_cache: bool, optional
        Deletes all the request cache if set to `True`. `None` if not desired.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, the limit is interpreted as a fraction of the total number of examples.
    :param samples: dictionary, optional
        Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0 to skip stderr calculations.
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param system_instruction: str
        System instruction to be applied to the prompt
    :param apply_chat_template: Union[bool, str]
        Specifies whether to apply a chat template to the prompt.
        - If set to True, the default chat template is applied.
        - If set to a string, applies the specified chat template by name.
        Defaults to False (no chat template applied).
    :param fewshot_as_multiturn: bool
        Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
    :param gen_kwargs: dict or comma-separated string
        Arguments for model generation. Ignored for all tasks with loglikelihood output_type.
    :param verbosity: str
        Verbosity level for logging
    :param predict_only: bool
        If True, only model outputs will be generated and returned. Metrics will not be evaluated.
    :param random_seed: int
        Random seed for python's random module. If set to None, the seed will not be set.
    :param numpy_random_seed: int
        Random seed for numpy. If set to None, the seed will not be set.
    :param torch_random_seed: int
        Random seed for torch. If set to None, the seed will not be set.
    :param fewshot_random_seed: int
        Random seed for the fewshot sampler's random generator. If set to None, the seed will not be set.
    :param metadata: dict
        Additional metadata to be added to the task manager. Will get passed to the download function of the task.

    :return: Dictionary of results
    """
    if config.verbosity is not None:
        setup_logging(verbosity=config.verbosity)
    start_date = time.time()

    if config.limit is not None and config.samples is not None:
        raise ValueError(
            "Only one of 'limit' or 'samples' may be specified; received both."
        )

    if isinstance(config.model_args, str) and (
        "instruct" in config.model_args and not config.apply_chat_template
    ):
        eval_logger.warning(
            "Instruct model detected, but chat template not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
        )

    if config.request_caching_args.get("delete_requests_cache", False):
        eval_logger.info("Deleting requests cache...")
        delete_cache()

    seed_message = []
    if config.seed[0] is not None:
        # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
        seed_message.append(f"Setting random seed to {config.seed[0]}")
        random.seed(config.seed[0])

    if config.seed[1] is not None:
        seed_message.append(f"Setting numpy seed to {config.seed[1]}")
        np.random.seed(config.seed[1])

    if config.seed[2] is not None:
        seed_message.append(f"Setting torch manual seed to {config.seed[2]}")
        torch.manual_seed(config.seed[2])

    if config.seed[3] is not None:
        seed_message.append(f"Setting fewshot manual seed to {config.seed[3]}")

    if seed_message:
        eval_logger.info(" | ".join(seed_message))

    if config.tasks is None:
        config.tasks = []
    if len(config.tasks) == 0:
        raise ValueError(
            "No tasks specified, or no tasks found. Please verify the task names."
        )

    if config.gen_kwargs is not None:
        if isinstance(config.gen_kwargs, str):
            config.gen_kwargs = simple_parse_args_string(config.gen_kwargs)
        eval_logger.warning(
            f"generation_kwargs: {config.gen_kwargs} specified through CLI; these settings will update parameters set in YAML tasks. "
            "Ensure 'do_sample=True' for non-greedy decoding!"
        )
        if not config.gen_kwargs:
            config.gen_kwargs = None

    if isinstance(config.model, str):
        if config.model_args is None:
            eval_logger.warning("model_args not specified. Using defaults.")
            config.model_args = ""

        if isinstance(config.model_args, dict):
            eval_logger.info(
                f"Initializing {config.model} model, with arguments: {config.model_args}"
            )
            lm = lm_eval.api.registry.get_model(config.model).create_from_arg_obj(
                config.model_args,
                {
                    "batch_size": config.batch_size,
                    "max_batch_size": config.max_batch_size,
                    "device": config.device,
                },
            )

        else:
            eval_logger.info(
                f"Initializing {config.model} model, with arguments: {simple_parse_args_string(config.model_args)}"
            )
            lm = lm_eval.api.registry.get_model(config.model).create_from_arg_string(
                config.model_args,
                {
                    "batch_size": config.batch_size,
                    "max_batch_size": config.max_batch_size,
                    "device": config.device,
                },
            )
    else:
        if not isinstance(config.model, lm_eval.api.model.LM):
            raise TypeError(
                f"The value of `model` passed to simple_evaluate() was of type {type(config.model)}, but is required to be a subclass of lm_eval.api.model.LM. This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
            )
        eval_logger.info("Using pre-initialized model")
        lm = config.model

    if config.use_cache is not None:
        eval_logger.info(
            f"Using cache at {config.use_cache + '_rank' + str(lm.rank) + '.db'}"
        )
        lm = lm_eval.api.model.CachingLM(
            lm,
            config.use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    if task_manager is None:
        task_manager = TaskManager(metadata=config.metadata)

    task_dict = get_task_dict(
        config.tasks,
        task_manager,
    )
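    # Note (illustrative): `task_dict` maps task and group names to Task objects,
    # with groups appearing as nested dicts. A hypothetical shape (names are examples
    # only) would be:
    #     {"mmlu": {"mmlu_anatomy": <Task>, "mmlu_astronomy": <Task>}, "hellaswag": <Task>}
    # `_adjust_config` below recurses into the nested group dicts and only mutates leaf tasks.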

    # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
    # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
    def _adjust_config(task_dict):
        adjusted_task_dict = {}
        for task_name, task_obj in task_dict.items():
            if isinstance(task_obj, dict):
                adjusted_task_dict = {
                    **adjusted_task_dict,
                    **{task_name: _adjust_config(task_obj)},
                }

            else:
                if task_obj.get_config("output_type") == "generate_until":
                    if config.gen_kwargs is not None:
                        task_obj.set_config(
                            key="generation_kwargs",
                            value=config.gen_kwargs,
                            update=True,
                        )
                    eval_logger.info(
                        f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
                    )

                if config.predict_only:
                    eval_logger.info(
                        f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                    )
                    # we have to change the class properties post-hoc. This is pretty hacky.
                    task_obj.override_metric(metric_name="bypass")

                # override tasks' fewshot values to the provided num_fewshot arg value
                # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
                if config.num_fewshot is not None:
                    if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                        eval_logger.info(
                            f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                        )
                    else:
                        eval_logger.warning(
                            f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {config.num_fewshot}"
                        )
                        task_obj.set_config(key="num_fewshot", value=config.num_fewshot)
                else:
                    # if num_fewshot not provided, and the task does not define a default one, default to 0
                    if (
                        default_num_fewshot := task_obj.get_config("num_fewshot")
                    ) is None:
                        task_obj.set_config(key="num_fewshot", value=0)
                # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
                task_obj.set_fewshot_seed(seed=config.seed[3])

                adjusted_task_dict[task_name] = task_obj

        return adjusted_task_dict

    task_dict = _adjust_config(task_dict)

    if config.check_integrity:
        run_task_tests(task_list=config.tasks)

    if evaluation_tracker is not None:
        evaluation_tracker.general_config_tracker.log_experiment_args(
            model_source=config.model,
            model_args=config.model_args,
            system_instruction=config.system_instruction,
            chat_template=lm.chat_template(config.apply_chat_template)
            if config.apply_chat_template
            else None,
            fewshot_as_multiturn=config.fewshot_as_multiturn,
        )

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=config.limit,
        samples=config.samples,
        cache_requests=config.cache_requests,
        rewrite_requests_cache=config.request_caching_args.get(
            "rewrite_requests_cache", False
        ),
        bootstrap_iters=bootstrap_iters,
        write_out=config.write_out,
        log_samples=True if config.predict_only else config.log_samples,
        system_instruction=config.system_instruction,
        apply_chat_template=config.apply_chat_template,
        fewshot_as_multiturn=config.fewshot_as_multiturn,
        verbosity=config.verbosity,
        confirm_run_unsafe_code=config.confirm_run_unsafe_code,
    )
    if config.verbosity is not None:
        setup_logging(verbosity=config.verbosity)

    if lm.rank == 0:
        if isinstance(config.model, str):
            model_name = config.model
        elif hasattr(config.model, "config") and hasattr(
            config.model.config, "_name_or_path"
        ):
            model_name = config.model.config._name_or_path
        else:
            model_name = type(config.model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": config.model_args,
        }
        # add more detailed model info if available
        if isinstance(lm, lm_eval.models.huggingface.HFLM):
            results["config"].update(lm.get_model_info())
        # add info about execution
        results["config"].update(
            {
                "batch_size": config.batch_size,
                "batch_sizes": (
                    list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
                ),
                "device": config.device,
                "use_cache": config.use_cache,
                "limit": config.limit,
                "bootstrap_iters": bootstrap_iters,
                "gen_kwargs": config.gen_kwargs,
                "random_seed": config.seed[0],
                "numpy_seed": config.seed[1],
                "torch_seed": config.seed[2],
                "fewshot_seed": config.seed[3],
            }
        )
        results["git_hash"] = get_git_commit_hash()
        results["date"] = start_date
        add_env_info(results)  # additional environment info to results
        add_tokenizer_info(results, lm)  # additional info about tokenizer

        return results
    else:
        return None
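
# Example usage (an illustrative, commented-out sketch; not executed on import).
# The EvaluationConfig field names below are assumed from the `config.*` attributes
# accessed above; check lm_eval.api.eval_config.EvaluationConfig for the actual
# constructor signature before relying on them.
#
#     config = EvaluationConfig(
#         model="hf",
#         model_args="pretrained=EleutherAI/pythia-160m",
#         tasks=["lambada_openai"],
#         num_fewshot=0,
#         batch_size=8,
#         device="cuda:0",
#     )
#     results = simple_evaluate(config=config)
#     if results is not None:  # non-zero ranks return None
#         print(results["results"])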


@positional_deprecated
def evaluate(
    lm: "LM",
    task_dict,
    limit: Optional[int] = None,
    samples: Optional[dict] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    bootstrap_iters: Optional[int] = 100000,
    write_out: bool = False,
    log_samples: bool = True,
    system_instruction: Optional[str] = None,
    apply_chat_template: Union[bool, str] = False,
    fewshot_as_multiturn: bool = False,
    verbosity: str = "INFO",
    confirm_run_unsafe_code: bool = False,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task.
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param samples: dictionary, optional
        Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests.
    :param rewrite_requests_cache: bool, optional
        Rewrites all the request cache if set to `True`.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 to skip all stderr calculations.
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param system_instruction: str
        System instruction to be applied to the prompt
    :param apply_chat_template: Union[bool, str]
        Specifies whether to apply a chat template to the prompt.
        - If set to True, the default chat template is applied.
        - If set to a string, applies the specified chat template by name.
        Defaults to False (no chat template applied).
    :param fewshot_as_multiturn: bool
        Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
    :param verbosity: str
        Verbosity level for logging
    :param confirm_run_unsafe_code: bool
        Whether to confirm running tasks marked as unsafe.
    :return: Dictionary of results
    """
    if limit is not None and samples is not None:
        raise ValueError(
            "Only one of 'limit' or 'samples' may be specified; received both."
        )
    if samples is not None:
        eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}")
    if apply_chat_template:
        eval_logger.warning(
            "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
        )
    # tracks all Instances/requests a model must generate output on.
    requests = defaultdict(list)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = defaultdict(int)

    # get lists of group hierarchy and each type of request
    eval_tasks = get_task_list(task_dict)
    if not log_samples:
        if not all(
            "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
            for task_output in eval_tasks
        ):
            raise ValueError("log_samples must be True for 'bypass' metric-only tasks")

    # validation checks:
    # 1. are we running a multimodal task with a non-multimodal model class, or vice-versa.
    # 2. are we running code that is marked as unsafe.
    incompatible_tasks = []
    for task_output in eval_tasks:
        task: Task = task_output.task

        if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
            incompatible_tasks.append(task_output.task_name)
        elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
            raise ValueError(
                f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
            )
    if len(incompatible_tasks) > 0:
        if not getattr(lm, "MULTIMODAL", False):
            raise ValueError(
                f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model types."
            )
        else:
            raise ValueError(
                f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
            )
    # end validation check

    # Cache the limit arg.
    limit_arg = limit
    limits = []
    for task_output in eval_tasks:
        task: Task = task_output.task

        limit = get_sample_size(task, limit_arg)
        limits.append(limit)
        task.build_all_requests(
            limit=limit,
            samples=samples.get(task_output.task_name, None)
            if samples is not None
            else samples,
            rank=lm.rank,
            world_size=lm.world_size,
            cache_requests=cache_requests,
            rewrite_requests_cache=rewrite_requests_cache,
            system_instruction=system_instruction,
            apply_chat_template=bool(apply_chat_template),
            fewshot_as_multiturn=fewshot_as_multiturn,
            chat_template=getattr(lm, "apply_chat_template")
            if apply_chat_template
            else None,
            tokenizer_name=getattr(lm, "tokenizer_name", "")
            if apply_chat_template
            else "",
        )
        eval_logger.debug(
            f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
        )
        if write_out:
            print_writeout(task)
        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )
            # "multiple_choice" task types dispatch (several) "loglikelihood" request types
            reqtype = (
                "loglikelihood"
                if task.OUTPUT_TYPE == "multiple_choice"
                else task.OUTPUT_TYPE
            )
            # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            # TODO: may not account for padding in cases like SquadV2 which has multiple req types
            padding_requests[reqtype] += numpad
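            # Worked example (illustrative): with world_size=2 and gathered per-rank
            # request counts [10, 8], rank 1 records numpad = 10 - 8 = 2 and later pads
            # its request list accordingly, so every rank runs the same number of
            # forward passes.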

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()
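
    # Note (illustrative): a request with repeats=2 appears twice in cloned_reqs above,
    # so after the response loop its `resps` list holds two model responses; the
    # per-sample postprocessing below consumes these repeated responses per request.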

    RANK = lm.rank
    WORLD_SIZE = lm.world_size
    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_output, limit in zip(eval_tasks, limits):
        task = task_output.task
        task.apply_filters()

        ### Collect values of metrics on all datapoints ###
        # unpack results and sort back in order and return control to Task
        # TODO: make it possible to use a different metric per filter
        # Pre-process task.instances to group by doc_id
        instances_by_doc_id = defaultdict(list)
        for instance in task.instances:
            instances_by_doc_id[instance.doc_id].append(instance)
        # Sort instances within each group
        for instances in instances_by_doc_id.values():
            instances.sort(key=lambda x: x.idx)
        # iterate over different filters used
        for filter_key in task.instances[0].filtered_resps.keys():
            indices = (
                samples.get(task_output.task_name, None)
                if samples is not None
                else None
            )
            doc_iterator = task.doc_iterator(
                rank=RANK,
                limit=limit,
                world_size=WORLD_SIZE,
                samples=indices,
            )
            for doc_id, doc in doc_iterator:
                if indices:
                    doc_id_true = indices[doc_id]
                else:
                    doc_id_true = doc_id
                requests = instances_by_doc_id[doc_id]
                metrics = task.process_results(
                    doc, [req.filtered_resps[filter_key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id_true,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [
                            req.filtered_resps[filter_key] for req in requests
                        ],
                        "filter": filter_key,
                        "metrics": list(metrics.keys()),
                        "doc_hash": hash_string(
                            json.dumps(
                                requests[0].doc,
                                indent=2,
                                default=handle_non_serializable,
                                ensure_ascii=False,
                            )
                        ),
                        "prompt_hash": hash_string(requests[0].arguments[0]),
                        "target_hash": hash_string(str(target)),
                    }
                    example.update(metrics)
                    task_output.logged_samples.append(example)
                for metric, value in metrics.items():
                    task_output.sample_metrics[(metric, filter_key)].append(value)

    if WORLD_SIZE > 1:
        # if multigpu, then gather data across all ranks to rank 0
        # first gather logged samples across all ranks
        for task_output in eval_tasks:
            if log_samples:
                # for task_name, task_samples in list(samples.items()):
                full_samples = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.logged_samples,
                    object_gather_list=full_samples,
                    dst=0,
                )

                if RANK == 0:
                    task_output.logged_samples = list(
                        itertools.chain.from_iterable(full_samples)
                    )

            # then collect metrics across all ranks
            for metrics in task_output.sample_metrics:
                metric_list = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.sample_metrics[metrics],
                    object_gather_list=metric_list,
                    dst=0,
                )
                if RANK == 0:
                    task_output.sample_metrics[metrics] = list(
                        itertools.chain.from_iterable(metric_list)
                    )

    if RANK == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for task_output in eval_tasks:
            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
        (
            results,
            samples,
            configs,
            versions,
            num_fewshot,
            higher_is_better,
        ) = consolidate_results(eval_tasks)

        ### Calculate group metrics ###
        if bool(results):
            results, versions, show_group_table, *_ = consolidate_group_results(
                results, versions, task_dict
            )

        results_agg, group_agg = prepare_print_tasks(task_dict, results)
        subtask_list = get_subtask_list(task_dict)

        # collect all higher_is_better values for metrics
        # in the group's subtasks.
        # TODO: clean this up ; unify with the below metric_list loop?
        _higher_is_better = {}
        for group, task_list in subtask_list.items():
            if (
                len(task_list) != 0
            ):  # subtask list will list "task_name": [] for solo tasks
                for task in task_list:
                    for m, h in higher_is_better[task].items():
                        if m not in _higher_is_better.keys():
                            _higher_is_better[m] = h

                        if (
                            m in _higher_is_better
                            and _higher_is_better[m] is not None
                            and _higher_is_better[m] != h
                        ):
                            eval_logger.warning(
                                f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
                            )
                            _higher_is_better[m] = None
                higher_is_better[group] = _higher_is_better

        results_dict = {
            "results": dict(results_agg.items()),
            **(
                {"groups": dict(group_agg.items())}
                if (bool(group_agg) & show_group_table)
                else {}
            ),
            "group_subtasks": dict(reversed(subtask_list.items())),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
            "higher_is_better": dict(sorted(higher_is_better.items())),
            "n-samples": {
                task_output.task_name: {
                    "original": len(task_output.task.eval_docs),
                    "effective": min(
                        limit if limit else len(task_output.task.eval_docs),
                        len(task_output.task.eval_docs),
                    ),
                }
                for task_output, limit in zip(eval_tasks, limits)
            },
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None
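
# Example usage of evaluate() (an illustrative, commented-out sketch; not executed on
# import). Assumes `lm` is an already-instantiated lm_eval.api.model.LM subclass; the
# task name is an example only.
#
#     task_manager = TaskManager()
#     task_dict = get_task_dict(["lambada_openai"], task_manager)
#     results = evaluate(lm=lm, task_dict=task_dict, limit=10)
#     if results is not None:  # non-zero ranks return None
#         print(results["results"])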


# def request_caching_arg_to_dict(cache_requests: str) -> dict:
#     request_caching_args = {
#         "cache_requests": cache_requests in {"true", "refresh"},
#         "rewrite_requests_cache": cache_requests == "refresh",
#         "delete_requests_cache": cache_requests == "delete",
#     }

#     return request_caching_args