# epoch_based_runner.py
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import platform
import shutil

import torch
from torch.optim import Optimizer

import mmcv
from mmcv.runner import RUNNERS, EpochBasedRunner
from .checkpoint import save_checkpoint

try:
    import apex  # NVIDIA apex provides the mixed-precision (AMP) utilities used in resume()
except ImportError:
    print("apex is not installed")


@RUNNERS.register_module()
class EpochBasedRunnerAmp(EpochBasedRunner):
    """Epoch-based Runner with AMP support.

    This runner trains models epoch by epoch.
    """

    def save_checkpoint(
        self,
        out_dir,
        filename_tmpl="epoch_{}.pth",
        save_optimizer=True,
        meta=None,
        create_symlink=True,
    ):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
        """
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        elif isinstance(meta, dict):
            meta.update(epoch=self.epoch + 1, iter=self.iter)
        else:
            raise TypeError(f"meta should be a dict or None, but got {type(meta)}")
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments `os.symlink` is not supported, so you may need
        # to set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, "latest.pth")
            if platform.system() != "Windows":
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)
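
    # Note (assumption, not shown in this file): the `save_checkpoint` imported
    # from `.checkpoint` is expected to also store the apex AMP state in the
    # checkpoint (e.g. under an 'amp' key), since `resume` below restores it
    # with `apex.amp.load_state_dict(checkpoint['amp'])`.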

    def resume(self, checkpoint, resume_optimizer=True, map_location="default"):
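        """Resume training state from a checkpoint, including the apex AMP
        state if it is present.

        Args:
            checkpoint (str): Path to the checkpoint file to resume from.
            resume_optimizer (bool, optional): Whether to restore the
                optimizer state stored in the checkpoint. Defaults to True.
            map_location (str, optional): Device mapping used when loading
                the checkpoint. With the default value 'default', the
                checkpoint is mapped to the current CUDA device if one is
                available. Defaults to 'default'.
        """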
        if map_location == "default":
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id),
                )
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(checkpoint, map_location=map_location)

        self._epoch = checkpoint["meta"]["epoch"]
        self._iter = checkpoint["meta"]["iter"]
        if "optimizer" in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint["optimizer"])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(checkpoint["optimizer"][k])
            else:
                raise TypeError(
                    "Optimizer should be dict or torch.optim.Optimizer "
                    f"but got {type(self.optimizer)}"
                )

        if "amp" in checkpoint:
            apex.amp.load_state_dict(checkpoint["amp"])
            self.logger.info("load amp state dict")

        self.logger.info("resumed epoch %d, iter %d", self.epoch, self.iter)