tune_hooks.py 1.79 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
from ray import tune

from fastreid.engine.hooks import EvalHook, flatten_results_dict
from fastreid.utils.checkpoint import Checkpointer


class TuneReportHook(EvalHook):
    """Evaluation hook that checkpoints the trainer and reports metrics to Ray Tune.

    Each evaluation that yields results is treated as one Tune iteration:
    a checkpoint is written into the directory handed out by
    ``tune.checkpoint_dir`` and ``r1`` / ``map`` / ``score`` are sent through
    ``tune.report``.

    Args:
        eval_period: forwarded unchanged to ``EvalHook.__init__``.
        eval_function: forwarded unchanged to ``EvalHook.__init__``; it is
            later invoked through ``self._func()`` (presumably stored under
            that name by the parent class -- confirm against ``EvalHook``).
    """

    def __init__(self, eval_period, eval_function):
        super().__init__(eval_period, eval_function)
        # Number of Tune iterations completed so far; passed as ``step`` to
        # ``tune.checkpoint_dir`` so every checkpoint gets a distinct index.
        self.step = 0

    def _do_eval(self):
        """Run the evaluation, then save a Tune checkpoint and report metrics.

        When the evaluation function returns nothing (``None`` or an empty
        dict) there are no metrics to report, so the checkpoint/report step
        is skipped instead of failing on the missing ``"Rank-1"`` / ``"mAP"``
        keys.

        Raises:
            ValueError: if any value of the flattened results cannot be
                converted to ``float``.
            KeyError: if non-empty results lack ``"Rank-1"`` or ``"mAP"``.
        """
        results = self._func()

        if results:
            assert isinstance(
                results, dict
            ), "Eval function must return a dict. Got {} instead.".format(results)

            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                try:
                    # Conversion is only a validity check; the value is unused.
                    float(v)
                except Exception as err:
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(k, v)
                    ) from err

        # Remove extra memory cache of main process due to evaluation
        torch.cuda.empty_cache()

        if not results:
            # Nothing was evaluated, so there is nothing to report to Tune.
            # Returning here (rather than indexing into empty results) keeps
            # ``self.step`` aligned with the iterations actually reported.
            return

        self.step += 1

        # Here we save a checkpoint. It is automatically registered with
        # RayTune and will potentially be passed as the `checkpoint_dir`
        # parameter in future iterations.
        with tune.checkpoint_dir(step=self.step) as checkpoint_dir:
            additional_state = {"epoch": int(self.trainer.epoch)}
            # Change path of save dir where tune can find.
            # NOTE: this rebinds the trainer checkpointer's ``save_dir`` and
            # does not restore the previous value afterwards.
            self.trainer.checkpointer.save_dir = checkpoint_dir
            self.trainer.checkpointer.save(name="checkpoint", **additional_state)

        rank1 = results["Rank-1"]
        mean_ap = results["mAP"]
        # ``score`` is the arithmetic mean of Rank-1 and mAP.
        tune.report(r1=rank1, map=mean_ap, score=(rank1 + mean_ap) / 2)