import random
import itertools
import json
import collections
import logging
import sys

import torch

import numpy as np

import lm_eval.api
import lm_eval.tasks
import lm_eval.benchmarks
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.model
import lm_eval.api.registry

from lm_eval.utils import (
    positional_deprecated,
    run_task_tests,
    make_table,
    create_iterator,
    get_git_commit_hash,
)

from lm_eval.logger import eval_logger

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=None,
    batch_size=None,
    max_batch_size=None,
    device=None,
    use_cache=None,
    limit=None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :return:
        Dictionary of results
    """
    random.seed(0)
    np.random.seed(1234)
    torch.manual_seed(
        1234
    )  # TODO: this may affect training runs that are run with evaluation mid-run.

    assert (
        tasks != []
    ), "No tasks specified, or no tasks found. Please verify the task names."

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
            },
        )
    else:
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

    if use_cache is not None:
        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank" + str(lm.rank) + ".db",
        )
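    # For example (illustrative path only): with use_cache="lm_cache/mymodel" and two ranks,
    # rank 0 reads/writes "lm_cache/mymodel_rank0.db" and rank 1 "lm_cache/mymodel_rank1.db".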

    task_dict = lm_eval.tasks.get_task_dict(tasks)
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if type(task_obj) == tuple:
            group, task_obj = task_obj
            if task_obj is None:
                continue

        config = task_obj._config
        if num_fewshot is not None:
            if config["num_fewshot"] > 0:
                default_num_fewshot = config["num_fewshot"]
                eval_logger.warning(
                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                )

            task_obj._config["num_fewshot"] = num_fewshot

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        log_samples=log_samples,
    )

    if lm.rank == 0:
        # add info about the model and few shot config
        results["config"] = {
            "model": model
            if isinstance(model, str)
            else model.model.config._name_or_path,
            "model_args": model_args,
            "batch_size": batch_size,
            "batch_sizes": list(lm.batch_sizes.values())
            if hasattr(lm, "batch_sizes")
            else [],
            "device": device,
            "use_cache": use_cache,
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
        }
        results["git_hash"] = get_git_commit_hash()
        return results
    else:
        return None
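
# Example usage (an illustrative sketch only; the model name "hf", the model_args string,
# and the task name "lambada_openai" are assumptions that depend on which models and
# tasks are registered in a given installation):
#
#     import lm_eval.evaluator
#
#     results = lm_eval.evaluator.simple_evaluate(
#         model="hf",                        # assumed key in lm_eval.api.registry
#         model_args="pretrained=gpt2",      # parsed by LM.create_from_arg_string
#         tasks=["lambada_openai"],          # assumed task name
#         num_fewshot=0,
#         batch_size=8,
#         device="cuda:0",
#         log_samples=False,
#     )
#     print(results["results"])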


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    limit=None,
    bootstrap_iters: int = 100000,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :return:
        Dictionary of results
    """

    # decontaminate = decontamination_ngrams_path is not None

    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    # Tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # Stores task scores based on task grouping.
    results_agg = collections.defaultdict(dict)
    groups_agg = collections.defaultdict(dict)
    # tracks if a task was chosen via user selecting a group containing it
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    task_hierarchy = collections.defaultdict(list)
    task_order = collections.defaultdict(int)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group_name, task = task
            task_hierarchy[group_name].append(task_name)
        else:
            task_hierarchy[task_name] = []

        if task is None:
            continue

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

        if limit is not None:
            if task.has_test_docs():
                task_docs = task.test_docs()
            elif task.has_validation_docs():
                task_docs = task.validation_docs()
            else:
                raise RuntimeError("Task has neither test_docs nor validation_docs")
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)

        task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)

        eval_logger.info(
            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
        )

        if write_out:
            for inst in task.instances:
                # print the prompt for the first few documents
                if inst.doc_id < 1:
                    eval_logger.info(
                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
                    )
                    eval_logger.info(f"Request: {str(inst)}")

        # aggregate Instances by LM method requested to get output.
        reqtype = (
            "loglikelihood"
            if task.OUTPUT_TYPE == "multiple_choice"
            else task.OUTPUT_TYPE
        )  # TODO: this is hacky, fix in task.py
        requests[reqtype].extend(task.instances)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )

            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            padding_requests[reqtype] += numpad
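            # e.g. with world_size=2, if rank 0 built 10 requests and rank 1 built 8,
            # rank 1 records numpad=2 so that it later issues 2 extra (duplicate) requests
            # and both ranks perform the same number of forward passes.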

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info("Running {} requests".format(reqtype))
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)
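                # note: `req` here is the last request left over from the loop above, so
                # padding re-submits copies of that final request purely to equalize the
                # number of forward passes across ranks.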

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group, task = task
            if task is None:
                continue
        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group, task = task
            if task is None:
                continue
        # TODO: make it possible to use a different metric per filter
        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():
            doc_iterator = (
                itertools.islice(
                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
                )
                if task.has_test_docs()
                else itertools.islice(
                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
                )
            )
            for doc_id, doc in doc_iterator:
                # subset instances to only this document id ; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [req.filtered_resps[key] for req in requests],
                    }
                    example.update(metrics)
                    samples[task_name].append(example)
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)
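                    # keys look like (task_name, filter_name, metric_name), e.g.
                    # ("hellaswag", "none", "acc") -- the task and filter names shown
                    # here are illustrative and depend on the configured tasks/filters.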

    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)

            samples[task_name] = list(itertools.chain.from_iterable(full_samples))

        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if type(items[0]) == tuple:
                numitem = len(items[0])

            if isinstance(items[0], (str, list)):
                # handle the string case
                gathered_items = [None] * lm.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_items, items)

                gathered_item = list(itertools.chain.from_iterable(gathered_items))
            else:
                # distributed gather requires all ranks to have same dimensions
                # so we pad out with float32 min value
                pad_value = torch.finfo(torch.float32).min
                metrics_tensor = torch.tensor(items, device=lm.device)

                original_dtype = metrics_tensor.dtype  # store original dtype
                torch_device_tensor = lm.accelerator.pad_across_processes(
                    metrics_tensor.to(torch.float32), pad_index=pad_value
                )
                gathered_item = lm.accelerator.gather(torch_device_tensor)

                if numitem > 0:
                    gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
                else:
                    gathered_filtered = gathered_item[gathered_item != pad_value]

                gathered_item = (
                    gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
                )
                # reconvert if we were passed a tuple of values
                if numitem > 0:
                    gathered_item = [tuple(g) for g in gathered_item]

            if lm.rank == 0:
                vals_torch[(task_name, key, metric)] = gathered_item

        vals = vals_torch

    if lm.rank == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            metric_key = metric + "," + key

            if type(task) == tuple:
                group_name, task = task
            else:
                group_name = None

            task_score = task.aggregation()[metric](items)

            if group_name is not None:
                sample_metric_key = metric + "(sample avg)," + key
                task_metric_key = metric + "(task avg)," + key
                if task_metric_key in results[group_name]:
                    results[group_name][task_metric_key].append(task_score)
                    results[group_name][sample_metric_key].extend(items)
                else:
                    results[group_name][task_metric_key] = [task_score]
                    results[group_name][sample_metric_key] = items

            results[task_name][metric_key] = task_score

            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them for fewer iterations. still looking for a cleaner way to do this
            if False:  # bootstrap_iters > 0:
                stderr = lm_eval.api.metrics.stderr_for_metric(
                    metric=task.aggregation()[metric],
                    bootstrap_iters=min(bootstrap_iters, 1000)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )

                if stderr is not None:
                    results[task_name][metric + "_stderr" + "," + key] = stderr(items)

        # zero_order_groups = [group for group in task_hierarchy if task_hierarchy[group] == 0]

        # for task_name, task in task_dict.items():
        #     if type(task) == tuple:
        #         group_name, _ = task
        #     else:
        #         group_name = None

        #     scores = results[task_name]
        #     if group_name is not None:
        #         group_name = tab_dict[group_name] * "-" + group_name
        #         if group_name not in results_agg:
        #             results_agg[group_name] = {}

        #         for metric in scores:
        #             if metric in results_agg[group_name]:
        #                 results_agg[group_name][metric].append(scores[metric])
        #             else:
        #                 results_agg[group_name][metric] = [scores[metric]]

        #     tab_task_name = tab_dict[task_name] * "-" + task_name
        #     results_agg[tab_task_name] = scores
        #     versions[tab_task_name] = versions[task_name]

        # if bool(results_agg):
        #     for group in results_agg.keys():
        #         for metric in results_agg[group].keys():
        #             results_agg[group][metric] = np.average(results_agg[group][metric])
        #             versions[group] = "N/A"

        if bool(results):
            for task_or_group in results.keys():
                for metric in results[task_or_group].keys():
                    if type(results[task_or_group][metric]) == list:
                        results[task_or_group][metric] = np.average(results[task_or_group][metric])
                        versions[task_or_group] = "N/A"

        print("task_hierarchy")
        print(task_hierarchy)
        print("--")
        for group in task_hierarchy.keys():
            if group not in task_order:
                task_order[group] = 0

            for task in task_hierarchy[group]:
                if task in task_order:
                    task_order[task] += 1
                else:
                    task_order[task] = 1 + task_order[group]

        print("task_order")
        print(task_order)
        print("--")
        for task_or_group, order in task_order.items():
            tabbed_name = ">" * order + task_or_group
            results_agg[tabbed_name] = results[task_or_group]
            versions[tabbed_name] = versions[task_or_group]
            if (order == 0) and len(task_hierarchy[task_or_group]) > 0:
                groups_agg[task_or_group] = results[task_or_group]
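        # For illustration: a group appears in `results_agg` under its own name (order 0),
        # while each task selected via that group is listed with a ">" prefix per level,
        # e.g. a hypothetical group "mygroup" with member task "mytask" shows up as
        # "mygroup" and ">mytask".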

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None
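
# Shape of the dictionary returned by `evaluate` on rank 0 (a sketch; the placeholder
# task, group, filter, and metric names are illustrative):
#
#     {
#         "results":  {"<task or >group-member name>": {"<metric>,<filter>": <score>, ...}, ...},
#         "groups":   {...},          # only present when grouped tasks were evaluated
#         "configs":  {"<task name>": <task config dict>, ...},
#         "versions": {"<task name>": <task version>, ...},
#         "samples":  {"<task name>": [<per-document log entries>, ...]},  # only if log_samples
#     }
#
# `simple_evaluate` additionally attaches a "config" entry describing the run and a
# "git_hash" entry from get_git_commit_hash().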