# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os

import torch
import torch.nn as nn

# Module-level logger; the checkpoint callbacks in this file log the
# destination path through it before writing.
_logger = logging.getLogger(__name__)


class Callback:
    """
    Callback provides an easy way to react to events like begin/end of epochs.

    Every hook defined on this base class is a no-op; subclasses override the
    ones they need.
    """

    def __init__(self):
        # All three references stay ``None`` until ``build`` is called.
        self.model = None
        self.mutator = None
        self.trainer = None

    def build(self, model, mutator, trainer):
        """
        Callback needs to be built with model, mutator, trainer, to get updates from them.

        Parameters
        ----------
        model : nn.Module
            Model to be trained.
        mutator : nn.Module
            Mutator that mutates the model.
        trainer : BaseTrainer
            Trainer that is to call the callback.
        """
        self.model = model
        self.mutator = mutator
        self.trainer = trainer

    def on_epoch_begin(self, epoch):
        """
        Implement this to do something at the begin of epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """

    def on_epoch_end(self, epoch):
        """
        Implement this to do something at the end of epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """

    def on_batch_begin(self, epoch):
        """
        Implement this to do something at the begin of a batch.

        Parameters
        ----------
        epoch
            Value supplied by the caller.
            NOTE(review): the parameter is named ``epoch``; presumably it is
            the current epoch number -- confirm against the trainer.
        """

    def on_batch_end(self, epoch):
        """
        Implement this to do something at the end of a batch.

        Parameters
        ----------
        epoch
            Value supplied by the caller.
            NOTE(review): the parameter is named ``epoch``; presumably it is
            the current epoch number -- confirm against the trainer.
        """


class LRSchedulerCallback(Callback):
    """
    Calls scheduler at the end of every epoch.

    Parameters
    ----------
    scheduler : LRScheduler
        Scheduler to be called.
    mode : str
        When to step the scheduler. Only ``"epoch"`` is supported.

    Raises
    ------
    AssertionError
        If ``mode`` is anything other than ``"epoch"``.
    """

    def __init__(self, scheduler, mode="epoch"):
        super().__init__()
        # Raised explicitly instead of through an ``assert`` statement so the
        # check is not stripped under ``python -O``; the exception type is
        # kept as ``AssertionError`` for backward compatibility.
        if mode != "epoch":
            raise AssertionError(
                "Unsupported mode {!r}: only 'epoch' is supported.".format(mode)
            )
        self.scheduler = scheduler
        self.mode = mode

    def on_epoch_end(self, epoch):
        """
        Call ``self.scheduler.step()`` on epoch end.

        The ``epoch`` argument is not used.
        """
        self.scheduler.step()


class ArchitectureCheckpoint(Callback):
    """
    Calls ``trainer.export()`` at the end of every epoch.

    The checkpoint directory is created on construction if it does not exist.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        # ``exist_ok=True`` makes re-using an existing directory a no-op.
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump to ``{checkpoint_dir}/epoch_{number}.json`` on epoch end.

        Parameters
        ----------
        epoch : int
            Epoch number, used to build the file name.
        """
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
        _logger.info("Saving architecture to %s", dest_path)
        self.trainer.export(dest_path)


class ModelCheckpoint(Callback):
    """
    Saves the model's ``state_dict`` with ``torch.save`` at the end of every epoch.

    The checkpoint directory is created on construction if it does not exist.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        # ``exist_ok=True`` makes re-using an existing directory a no-op.
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump to ``{checkpoint_dir}/epoch_{number}.pth.tar`` on every epoch end.
        ``DataParallel`` object will have their inside modules exported.

        Parameters
        ----------
        epoch : int
            Epoch number, used to build the file name.
        """
        model = self.model
        if isinstance(model, nn.DataParallel):
            # Export the wrapped module's state dict, not the wrapper's.
            model = model.module
        state_dict = model.state_dict()
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
        _logger.info("Saving model to %s", dest_path)
        torch.save(state_dict, dest_path)