import itertools
import json
import logging
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch

import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
    consolidate_group_results,
    consolidate_results,
    get_sample_size,
    get_subtask_list,
    get_task_list,
    prepare_print_tasks,
    print_writeout,
    run_task_tests,
)
from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
    handle_non_serializable,
    hash_string,
    positional_deprecated,
    setup_logging,
    simple_parse_args_string,
)
from lm_eval.api.eval_config import EvaluationConfig


if TYPE_CHECKING:
    from lm_eval.api.model import LM
    from lm_eval.api.task import Task

eval_logger = logging.getLogger(__name__)


@positional_deprecated
def simple_evaluate(
    config: "EvaluationConfig",
    # TODO: bootstrap_iters is not passed from cli_evaluate
    bootstrap_iters: int = 100000,
    evaluation_tracker: Optional[EvaluationTracker] = None,
    task_manager: Optional[TaskManager] = None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str, dict]
        String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_obj.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, dict, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests. `None` if not caching.
    :param rewrite_requests_cache: bool, optional
        Rewrites all of the request cache if set to `True`. `None` if not desired.
    :param delete_requests_cache: bool, optional
        Deletes all of the request cache if set to `True`. `None` if not desired.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
    :param samples: dictionary, optional
        Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0 to skip all stderr calculations.
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param system_instruction: str
        System instruction to be applied to the prompt
    :param apply_chat_template: Union[bool, str]
        Specifies whether to apply a chat template to the prompt.
        - If set to True, the default chat template is applied.
        - If set to a string, applies the specified chat template by name.
        Defaults to False (no chat template applied).
    :param fewshot_as_multiturn: bool
        Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
    :param gen_kwargs: dict or comma-separated string
        Arguments for model generation.
        Ignored for all tasks with loglikelihood output_type.
    :param verbosity: str
        Verbosity level for logging
    :param predict_only: bool
        If True, only model outputs will be generated and returned; metrics will not be evaluated.
    :param random_seed: int
        Random seed for python's random module. If set to None, the seed will not be set.
    :param numpy_random_seed: int
        Random seed for numpy. If set to None, the seed will not be set.
    :param torch_random_seed: int
        Random seed for torch. If set to None, the seed will not be set.
    :param fewshot_random_seed: int
        Random seed for the fewshot sampler's random generator. If set to None, the seed of the generator will be set to None.
    :param metadata: dict
        Additional metadata to be added to the task manager. Will get passed to the download function of the task.

    :return
        Dictionary of results
    """
    if config.verbosity is not None:
        setup_logging(verbosity=config.verbosity)
    start_date = time.time()

    if config.limit is not None and config.samples is not None:
        raise ValueError(
            "Either 'limit' or 'samples' must be None, but both are not None."
        )

    if isinstance(config.model_args, str) and (
        "instruct" in config.model_args and not config.apply_chat_template
    ):
        eval_logger.warning(
            "Instruct model detected, but chat template not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
        )

    if config.request_caching_args.get("delete_requests_cache", False):
        eval_logger.info("Deleting requests cache...")
        delete_cache()

    seed_message = []
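    # config.seed is expected to be a 4-element sequence:
    # (python random, numpy, torch, fewshot sampler) -- e.g. (0, 1234, 1234, 1234),
    # where the concrete values shown are purely illustrative.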
    if config.seed[0] is not None:
        # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
        seed_message.append(f"Setting random seed to {config.seed[0]}")
        random.seed(config.seed[0])

    if config.seed[1] is not None:
        seed_message.append(f"Setting numpy seed to {config.seed[1]}")
        np.random.seed(config.seed[1])

    if config.seed[2] is not None:
        seed_message.append(f"Setting torch manual seed to {config.seed[2]}")
        torch.manual_seed(config.seed[2])

    if config.seed[3] is not None:
        seed_message.append(f"Setting fewshot manual seed to {config.seed[3]}")

    if seed_message:
        eval_logger.info(" | ".join(seed_message))

    if config.tasks is None:
        config.tasks = []
    if len(config.tasks) == 0:
        raise ValueError(
            "No tasks specified, or no tasks found. Please verify the task names."
        )

    if config.gen_kwargs is not None:
        if isinstance(config.gen_kwargs, str):
            config.gen_kwargs = simple_parse_args_string(config.gen_kwargs)
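            # A sketch of the expected string form (values are illustrative): a CLI
            # argument like "temperature=0.7,top_p=0.95,do_sample=True" is parsed by
            # simple_parse_args_string into a dict along the lines of
            # {"temperature": 0.7, "top_p": 0.95, "do_sample": True}.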
        eval_logger.warning(
            f"generation_kwargs: {config.gen_kwargs} specified through cli, these settings will update the parameters set in the task YAML configs. "
            "Ensure 'do_sample=True' for non-greedy decoding!"
        )
        if not config.gen_kwargs:
            config.gen_kwargs = None

    if isinstance(config.model, str):
        if config.model_args is None:
            eval_logger.warning("model_args not specified. Using defaults.")
            config.model_args = ""

        if isinstance(config.model_args, dict):
            eval_logger.info(
                f"Initializing {config.model} model, with arguments: {config.model_args}"
            )
            lm = lm_eval.api.registry.get_model(config.model).create_from_arg_obj(
                config.model_args,
                {
                    "batch_size": config.batch_size,
                    "max_batch_size": config.max_batch_size,
                    "device": config.device,
                },
            )

        else:
            eval_logger.info(
                f"Initializing {config.model} model, with arguments: {simple_parse_args_string(config.model_args)}"
            )
            lm = lm_eval.api.registry.get_model(config.model).create_from_arg_string(
                config.model_args,
                {
                    "batch_size": config.batch_size,
                    "max_batch_size": config.max_batch_size,
                    "device": config.device,
                },
            )
    else:
        if not isinstance(config.model, lm_eval.api.model.LM):
            raise TypeError(
                f"The value of `model` passed to simple_evaluate() was of type {type(config.model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
            )
        eval_logger.info("Using pre-initialized model")
        lm = config.model

    if config.use_cache is not None:
        eval_logger.info(f"Using cache at {config.use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            config.use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )
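        # e.g. use_cache="lm_cache/sqlite" on rank 0 wraps the model so responses are
        # read from and written to "lm_cache/sqlite_rank0.db" (path shown is illustrative).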

    if task_manager is None:
        task_manager = TaskManager(metadata=config.metadata)

    task_dict = get_task_dict(
        config.tasks,
        task_manager,
    )
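    # task_dict may be nested when groups are requested, e.g. (illustrative):
    # {"mmlu": {"mmlu_anatomy": <Task>, "mmlu_astronomy": <Task>}, "lambada_openai": <Task>}
    # _adjust_config below recurses through such dicts and only touches leaf Task objects.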

    # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
    # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
    def _adjust_config(task_dict):
        adjusted_task_dict = {}
        for task_name, task_obj in task_dict.items():
            if isinstance(task_obj, dict):
                adjusted_task_dict = {
                    **adjusted_task_dict,
                    **{task_name: _adjust_config(task_obj)},
                }

            else:
                if task_obj.get_config("output_type") == "generate_until":
                    if config.gen_kwargs is not None:
                        task_obj.set_config(
                            key="generation_kwargs", value=config.gen_kwargs, update=True
                        )
                    eval_logger.info(
                        f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
                    )

                if config.predict_only:
                    eval_logger.info(
                        f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                    )
                    # we have to change the class properties post-hoc. This is pretty hacky.
                    task_obj.override_metric(metric_name="bypass")

                # override tasks' fewshot values to the provided num_fewshot arg value
                # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
                if config.num_fewshot is not None:
                    if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                        eval_logger.info(
                            f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                        )
                    else:
                        eval_logger.warning(
                            f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {config.num_fewshot}"
                        )
                        task_obj.set_config(key="num_fewshot", value=config.num_fewshot)
                else:
                    # if num_fewshot not provided, and the task does not define a default one, default to 0
                    if (
                        default_num_fewshot := task_obj.get_config("num_fewshot")
                    ) is None:
                        task_obj.set_config(key="num_fewshot", value=0)
                # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
                task_obj.set_fewshot_seed(seed=config.seed[3])

                adjusted_task_dict[task_name] = task_obj

        return adjusted_task_dict

    task_dict = _adjust_config(task_dict)

    if config.check_integrity:
        run_task_tests(task_list=config.tasks)

    if evaluation_tracker is not None:
        evaluation_tracker.general_config_tracker.log_experiment_args(
            model_source=config.model,
            model_args=config.model_args,
            system_instruction=config.system_instruction,
            chat_template=lm.chat_template(config.apply_chat_template)
            if config.apply_chat_template
            else None,
            fewshot_as_multiturn=config.fewshot_as_multiturn,
        )

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=config.limit,
        samples=config.samples,
        cache_requests=config.cache_requests,
        rewrite_requests_cache=config.request_caching_args.get("rewrite_requests_cache", False),
        bootstrap_iters=bootstrap_iters,
        write_out=config.write_out,
        log_samples=True if config.predict_only else config.log_samples,
        system_instruction=config.system_instruction,
        apply_chat_template=config.apply_chat_template,
        fewshot_as_multiturn=config.fewshot_as_multiturn,
        verbosity=config.verbosity,
        confirm_run_unsafe_code=config.confirm_run_unsafe_code,
    )
    if config.verbosity is not None:
        setup_logging(verbosity=config.verbosity)

    if lm.rank == 0:
        if isinstance(config.model, str):
            model_name = config.model
        elif hasattr(config.model, "config") and hasattr(config.model.config, "_name_or_path"):
            model_name = config.model.config._name_or_path
        else:
            model_name = type(config.model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": config.model_args,
        }
        # add more detailed model info if available
        if isinstance(lm, lm_eval.models.huggingface.HFLM):
            results["config"].update(lm.get_model_info())
        # add info about execution
        results["config"].update(
            {
                "batch_size": config.batch_size,
                "batch_sizes": (
                    list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
                ),
                "device": config.device,
                "use_cache": config.use_cache,
                "limit": config.limit,
                "bootstrap_iters": bootstrap_iters,
                "gen_kwargs": config.gen_kwargs,
                "random_seed": config.seed[0],
                "numpy_seed": config.seed[1],
                "torch_seed": config.seed[2],
                "fewshot_seed": config.seed[3],
            }
        )
        results["git_hash"] = get_git_commit_hash()
        results["date"] = start_date
        add_env_info(results)  # additional environment info to results
        add_tokenizer_info(results, lm)  # additional info about tokenizer
        return results
    else:
        return None


@positional_deprecated
def evaluate(
    lm: "LM",
    task_dict,
    limit: Optional[int] = None,
    samples: Optional[dict] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    bootstrap_iters: Optional[int] = 100000,
    write_out: bool = False,
    log_samples: bool = True,
    system_instruction: Optional[str] = None,
    apply_chat_template: Union[bool, str] = False,
    fewshot_as_multiturn: bool = False,
    verbosity: str = "INFO",
    confirm_run_unsafe_code: bool = False,
):
    """Evaluate an instantiated model on a dictionary of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param samples: dictionary, optional
        Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests.
    :param rewrite_requests_cache: bool, optional
        Rewrites all of the request cache if set to `True`.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 to skip all stderr calculations.
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param system_instruction: str
        System instruction to be applied to the prompt
    :param apply_chat_template: Union[bool, str]
        Specifies whether to apply a chat template to the prompt.
        - If set to True, the default chat template is applied.
        - If set to a string, applies the specified chat template by name.
        Defaults to False (no chat template applied).
    :param fewshot_as_multiturn: bool
        Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
    :param verbosity: str
        Verbosity level for logging
    :param confirm_run_unsafe_code: bool
        Whether to confirm running tasks marked as unsafe.
    :return
        Dictionary of results
    """

    if limit is not None and samples is not None:
        raise ValueError(
            "Either 'limit' or 'samples' must be None, but both are not None."
        )
    if samples is not None:
        eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}")
    if apply_chat_template:
        eval_logger.warning(
            "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
        )
    # tracks all Instances/requests a model must generate output on.
    requests = defaultdict(list)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = defaultdict(int)

    # get lists of group hierarchy and each type of request
    eval_tasks = get_task_list(task_dict)
    if not log_samples:
        if not all(
            "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
            for task_output in eval_tasks
        ):
            raise ValueError("log_samples must be True for 'bypass' metric-only tasks")

    # validation checks:
    # 1. are we running a multimodal task <-> non-multimodal model class, or vice-versa.
    # 2. are we running code that is marked as unsafe.
    incompatible_tasks = []
    for task_output in eval_tasks:
        task: Task = task_output.task

        if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
            incompatible_tasks.append(task_output.task_name)
        elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
            raise ValueError(
                f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
            )
    if len(incompatible_tasks) > 0:
        if not getattr(lm, "MULTIMODAL", False):
            raise ValueError(
                f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
            )
        else:
            raise ValueError(
                f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
            )
    # end validation check

    # Cache the limit arg.
    limit_arg = limit
    limits = []
    for task_output in eval_tasks:
        task: Task = task_output.task

        limit = get_sample_size(task, limit_arg)
        limits.append(limit)
        task.build_all_requests(
            limit=limit,
            samples=samples.get(task_output.task_name, None)
            if samples is not None
            else samples,
            rank=lm.rank,
            world_size=lm.world_size,
            cache_requests=cache_requests,
            rewrite_requests_cache=rewrite_requests_cache,
            system_instruction=system_instruction,
            apply_chat_template=bool(apply_chat_template),
            fewshot_as_multiturn=fewshot_as_multiturn,
            chat_template=getattr(lm, "apply_chat_template")
            if apply_chat_template
            else None,
            tokenizer_name=getattr(lm, "tokenizer_name", "")
            if apply_chat_template
            else "",
        )
        eval_logger.debug(
            f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
        )
        if write_out:
            print_writeout(task)
        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )
            # "multiple_choice" task types dispatch (several) "loglikelihood" request types
            reqtype = (
                "loglikelihood"
                if task.OUTPUT_TYPE == "multiple_choice"
                else task.OUTPUT_TYPE
            )
            # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            # todo: may not account for padding in cases like SquadV2 which has multiple req types
            padding_requests[reqtype] += numpad
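            # e.g. with per-rank request counts [10, 12] (illustrative), rank 0 pads
            # by 2 so every rank performs the same number of forward passes.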

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    RANK = lm.rank
    WORLD_SIZE = lm.world_size
    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_output, limit in zip(eval_tasks, limits):
        task = task_output.task
        task.apply_filters()

        ### Collect values of metrics on all datapoints ###
        # unpack results and sort back in order and return control to Task
        # TODO: make it possible to use a different metric per filter
        # Pre-process task.instances to group by doc_id
        instances_by_doc_id = defaultdict(list)
        for instance in task.instances:
            instances_by_doc_id[instance.doc_id].append(instance)
        # Sort instances within each group
        for instances in instances_by_doc_id.values():
            instances.sort(key=lambda x: x.idx)
        # iterate over different filters used
        for filter_key in task.instances[0].filtered_resps.keys():
            indices = (
                samples.get(task_output.task_name, None)
                if samples is not None
                else None
            )
            doc_iterator = task.doc_iterator(
                rank=RANK,
                limit=limit,
                world_size=WORLD_SIZE,
                samples=indices,
            )
            for doc_id, doc in doc_iterator:
                if indices:
                    doc_id_true = indices[doc_id]
                else:
                    doc_id_true = doc_id
                requests = instances_by_doc_id[doc_id]
                metrics = task.process_results(
                    doc, [req.filtered_resps[filter_key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id_true,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [
                            req.filtered_resps[filter_key] for req in requests
                        ],
                        "filter": filter_key,
                        "metrics": list(metrics.keys()),
                        "doc_hash": hash_string(
                            json.dumps(
                                requests[0].doc,
                                indent=2,
                                default=handle_non_serializable,
                                ensure_ascii=False,
                            )
                        ),
                        "prompt_hash": hash_string(requests[0].arguments[0]),
                        "target_hash": hash_string(str(target)),
                    }
                    example.update(metrics)
                    task_output.logged_samples.append(example)
                for metric, value in metrics.items():
                    task_output.sample_metrics[(metric, filter_key)].append(value)

    if WORLD_SIZE > 1:
        # if multigpu, then gather data across all ranks to rank 0
        # first gather logged samples across all ranks
        for task_output in eval_tasks:
            if log_samples:
                # for task_name, task_samples in list(samples.items()):
                full_samples = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.logged_samples,
                    object_gather_list=full_samples,
                    dst=0,
                )

                if RANK == 0:
                    task_output.logged_samples = list(
                        itertools.chain.from_iterable(full_samples)
                    )

            # then collect metrics across all ranks
            for metrics in task_output.sample_metrics:
                metric_list = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.sample_metrics[metrics],
                    object_gather_list=metric_list,
                    dst=0,
                )
                if RANK == 0:
                    task_output.sample_metrics[metrics] = list(
                        itertools.chain.from_iterable(metric_list)
                    )

    if RANK == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for task_output in eval_tasks:
            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
        (
            results,
            samples,
            configs,
            versions,
            num_fewshot,
            higher_is_better,
        ) = consolidate_results(eval_tasks)

        ### Calculate group metrics ###
        if bool(results):
            results, versions, show_group_table, *_ = consolidate_group_results(
                results, versions, task_dict
            )

        results_agg, group_agg = prepare_print_tasks(task_dict, results)
        subtask_list = get_subtask_list(task_dict)

        # collect all higher_is_better values for metrics
        # in the group's subtasks.
        # TODO: clean this up ; unify with the below metric_list loop?
        _higher_is_better = {}
        for group, task_list in subtask_list.items():
            if (
                len(task_list) != 0
            ):  # subtask list will list "task_name": [] for solo tasks
                for task in task_list:
                    for m, h in higher_is_better[task].items():
                        if m not in _higher_is_better.keys():
                            _higher_is_better[m] = h

                        if (
                            m in _higher_is_better
                            and _higher_is_better[m] is not None
                            and _higher_is_better[m] != h
                        ):
                            eval_logger.warning(
                                f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
                            )
                            _higher_is_better[m] = None
                higher_is_better[group] = _higher_is_better

        results_dict = {
            "results": dict(results_agg.items()),
            **(
                {"groups": dict(group_agg.items())}
                if (bool(group_agg) & show_group_table)
                else {}
            ),
            "group_subtasks": dict(reversed(subtask_list.items())),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
            "higher_is_better": dict(sorted(higher_is_better.items())),
            "n-samples": {
                task_output.task_name: {
                    "original": len(task_output.task.eval_docs),
                    "effective": min(
                        limit if limit else len(task_output.task.eval_docs),
                        len(task_output.task.eval_docs),
                    ),
                }
                for task_output, limit in zip(eval_tasks, limits)
            },
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None


# def request_caching_arg_to_dict(cache_requests: str) -> dict:
#     request_caching_args = {
#         "cache_requests": cache_requests in {"true", "refresh"},
#         "rewrite_requests_cache": cache_requests == "refresh",
#         "delete_requests_cache": cache_requests == "delete",
#     }

#     return request_caching_args
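# e.g. per the mapping above, request_caching_arg_to_dict("refresh") would yield:
#     {"cache_requests": True, "rewrite_requests_cache": True, "delete_requests_cache": False}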