# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""
import logging
import os
import time

from torch.nn.parallel import DistributedDataParallel
from torch.nn.utils import clip_grad_norm_

from fastreid.data.build import _root, build_reid_test_loader, build_reid_train_loader
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.transforms import build_transforms
from fastreid.engine import hooks
from fastreid.engine.defaults import DefaultTrainer, TrainerBase
from fastreid.engine.train_loop import SimpleTrainer, AMPTrainer
from fastreid.solver import build_optimizer
from fastreid.utils import comm
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger
from fastreid.utils.params import ContiguousParams
from .face_data import MXFaceDataset
from .face_data import TestFaceDataset
from .face_evaluator import FaceEvaluator
from .modeling import PartialFC
from .pfc_checkpointer import PfcPeriodicCheckpointer, PfcCheckpointer
from .utils_amp import MaxClipGradScaler

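# A minimal launch sketch (illustrative only; it assumes a fastreid-style
# train_net.py entry point, and `setup`/`args` below are placeholder names):
#
#     cfg = setup(args)                          # build and freeze the config
#     trainer = FaceTrainer(cfg)                 # model, Partial-FC head, loaders, hooks
#     trainer.resume_or_load(resume=args.resume)
#     trainer.train()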

class FaceTrainer(DefaultTrainer):
    def __init__(self, cfg):
        TrainerBase.__init__(self)

        logger = logging.getLogger('fastreid.partial-fc.trainer')
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
        model = self.build_model(cfg)
        optimizer, param_wrapper = self.build_optimizer(cfg, model)

        if cfg.MODEL.HEADS.PFC.ENABLED:
            # fmt: off
            feat_dim      = cfg.MODEL.BACKBONE.FEAT_DIM
            embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
            num_classes   = cfg.MODEL.HEADS.NUM_CLASSES
            sample_rate   = cfg.MODEL.HEADS.PFC.SAMPLE_RATE
            cls_type      = cfg.MODEL.HEADS.CLS_LAYER
            scale         = cfg.MODEL.HEADS.SCALE
            margin        = cfg.MODEL.HEADS.MARGIN
            # fmt: on
            # Partial-FC head: each rank holds a shard of the class centers and
            # samples negative centers at `sample_rate` (see the PFCTrainer docstring).
            embedding_size = embedding_dim if embedding_dim > 0 else feat_dim
            self.pfc_module = PartialFC(embedding_size, num_classes, sample_rate, cls_type, scale, margin)
            self.pfc_optimizer, _ = build_optimizer(cfg, self.pfc_module, False)

        # For training, wrap the model with DistributedDataParallel; this is not needed for inference.
        if comm.get_world_size() > 1:
            # See https://github.com/pytorch/pytorch/issues/22049 on setting `find_unused_parameters=True`
            # when some of the parameters are not updated during the backward pass.
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
            )

        if cfg.MODEL.HEADS.PFC.ENABLED:
            mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
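            # AMP gradient scaler dedicated to the Partial-FC branch: initial scale of
            # mini_batch_size, with an upper bound of 128 * mini_batch_size.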
            grad_scaler = MaxClipGradScaler(mini_batch_size, 128 * mini_batch_size, growth_interval=100)
            self._trainer = PFCTrainer(model, data_loader, optimizer, param_wrapper,
                                       self.pfc_module, self.pfc_optimizer, cfg.SOLVER.AMP.ENABLED, grad_scaler)
        else:
            self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
                model, data_loader, optimizer, param_wrapper
            )

        self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, optimizer, self.iters_per_epoch)
        if cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_scheduler = self.build_lr_scheduler(cfg, self.pfc_optimizer, self.iters_per_epoch)

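        # build_lr_scheduler returns a dict of scheduler objects; unpack it so the
        # checkpointer also saves and restores the scheduler state.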
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            **self.scheduler,
        )

        if cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_checkpointer = PfcCheckpointer(
                self.pfc_module,
                cfg.OUTPUT_DIR,
                optimizer=self.pfc_optimizer,
                **self.pfc_scheduler,
            )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def build_hooks(self):
        ret = super().build_hooks()

        if self.cfg.MODEL.HEADS.PFC.ENABLED:
            # Keep the writer as the last hook; insert the PFC checkpointer just before it
            ret.insert(
                len(ret) - 1,
                PfcPeriodicCheckpointer(self.pfc_checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD)
            )
            # partial fc scheduler hook
            ret.append(
                hooks.LRScheduler(self.pfc_optimizer, self.pfc_scheduler)
            )
        return ret

    def resume_or_load(self, resume=True):
        # Load the backbone state_dict
        super().resume_or_load(resume)
        # Load the Partial-FC state_dict
        if self.cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_checkpointer.resume_or_load('', resume=resume)

    @classmethod
    def build_train_loader(cls, cfg):
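        # If DATASETS.REC_PATH is set, read training data from an MXNet RecordIO file
        # (insightface-style .rec) via MXFaceDataset; otherwise fall back to the
        # default fastreid train loader.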
        path_imgrec = cfg.DATASETS.REC_PATH
        if path_imgrec != "":
            transforms = build_transforms(cfg, is_train=True)
            train_set = MXFaceDataset(path_imgrec, transforms)
            return build_reid_train_loader(cfg, train_set=train_set)
        else:
            return DefaultTrainer.build_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
        test_set = TestFaceDataset(dataset.carray, dataset.is_same)
        data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
        return data_loader, test_set.labels

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_dir=None):
        if output_dir is None:
            output_dir = os.path.join(cfg.OUTPUT_DIR, "visualization")
        data_loader, labels = cls.build_test_loader(cfg, dataset_name)
        return data_loader, FaceEvaluator(cfg, labels, dataset_name, output_dir)


class PFCTrainer(SimpleTrainer):
    """
    Partial FC: Training 10 Million Identities on a Single Machine.
    Authors: Xiang An, Yang Xiao, XuHan Zhu (DeepGlint).
    Paper: https://arxiv.org/abs/2010.05222
    Code based on:
    https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/partial_fc.py
    """

    def __init__(self, model, data_loader, optimizer, param_wrapper, pfc_module, pfc_optimizer, amp_enabled,
                 grad_scaler):
        super().__init__(model, data_loader, optimizer, param_wrapper)

        self.pfc_module = pfc_module
        self.pfc_optimizer = pfc_optimizer
        self.amp_enabled = amp_enabled

        self.grad_scaler = grad_scaler

    def run_step(self):
        assert self.model.training, "[PFCTrainer] model was changed to eval mode!"
        start = time.perf_counter()

        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

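        # Backbone forward: produces the embedding features and the corresponding labels.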
        features, targets = self.model(data)

        # Partial-FC forward/backward: the sampled classifier computes the classification
        # loss and returns its gradient w.r.t. the features, to be back-propagated below.
        f_grad, loss_v = self.pfc_module.forward_backward(features, targets, self.pfc_optimizer)

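        # Back-propagate the feature gradient through the backbone. With AMP the gradient
        # is scaled first, then unscaled before clipping and the backbone optimizer step.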
        if self.amp_enabled:
            features.backward(self.grad_scaler.scale(f_grad))
            self.grad_scaler.unscale_(self.optimizer)
            clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2)
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            features.backward(f_grad)
            clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2)
            self.optimizer.step()

        loss_dict = {"loss_cls": loss_v}
        self._write_metrics(loss_dict, data_time)

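        # Step the Partial-FC optimizer and let the module update its class-center state,
        # then clear all gradients for the next iteration.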
        self.pfc_optimizer.step()
        self.pfc_module.update()
        self.optimizer.zero_grad()
        self.pfc_optimizer.zero_grad()
        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()