contrastive_trainer.py 8.81 KB
Newer Older
zhanggzh's avatar
zhanggzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import tensorflow as tf
from tensorflow import keras

from keras_cv.utils.train import convert_inputs_to_tf_dataset


class ContrastiveTrainer(keras.Model):
    """Creates a self-supervised contrastive trainer for a model.

    Each training step produces two augmented views of the input images (one
    per side of the contrastive pipeline). It encodes and projects both views
    and minimizes the compiled encoder loss between the two projections. When
    a `probe` is supplied, it is trained in the same step against the class
    labels. Its input is the gradient-stopped encoder features of the
    un-augmented images.

    Args:
        encoder: a `keras.Model` to be pre-trained. In most cases, this encoder
            should not include a top dense layer. Its output must be rank 2.
        augmenter: a preprocessing layer to randomly augment input images for
            contrastive learning, or a tuple of two separate augmenters for
            the two sides of the contrastive pipeline.
        projector: a projection model for contrastive training, or a tuple of
            two separate projectors for the two sides of the contrastive
            pipeline. This shrinks the feature map produced by the encoder,
            and is usually a 1 or 2-layer dense MLP.
        probe: An optional Keras layer or model which will be trained against
            class labels at train-time using the encoder output as input.
            Note that this should be specified iff training with labeled images.
            This predicts class labels based on the feature map produced by the
            encoder and is usually a 1 or 2-layer dense MLP.

    Returns:
      A `keras.Model` instance.

    Raises:
        ValueError: if the encoder output is not rank 2, or if `augmenter` or
            `projector` is a tuple whose length is not exactly 2.

    Usage:
    ```python
    encoder = keras_cv.models.DenseNet121(include_rescaling=True, include_top=False, pooling="avg")
    augmenter = keras_cv.layers.preprocessing.RandomFlip()
    projector = keras.layers.Dense(64)
    probe = keras_cv.training.ContrastiveTrainer.linear_probe(classes=10)

    trainer = keras_cv.training.ContrastiveTrainer(
        encoder=encoder,
        augmenter=augmenter,
        projector=projector,
        probe=probe
    )

    trainer.compile(
        encoder_optimizer=keras.optimizers.Adam(),
        encoder_loss=keras_cv.losses.SimCLRLoss(temperature=0.5),
        probe_optimizer=keras.optimizers.Adam(),
        probe_loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        probe_metrics=[keras.metrics.CategoricalAccuracy(name="probe_accuracy")]
    )

    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
    y_train = keras.utils.to_categorical(y_train, 10)

    trainer.fit(x_train, y_train)
    ```

    """

    def __init__(
        self,
        encoder,
        augmenter,
        projector,
        probe=None,
    ):
        super().__init__()

        if encoder.output.shape.rank != 2:
            raise ValueError(
                f"`encoder` must have a flattened output.  Expected rank(encoder.output.shape)=2, got encoder.output.shape={encoder.output.shape}"
            )

        # Validate both arguments before assigning any attributes.
        augmenters = self._as_pair(augmenter, "augmenter")
        projectors = self._as_pair(projector, "projector")

        self.augmenters = augmenters
        self.encoder = encoder
        self.projectors = projectors
        self.probe = probe

        self.loss_metric = keras.metrics.Mean(name="loss")

        if probe is not None:
            self.probe_loss_metric = keras.metrics.Mean(name="probe_loss")
            self.probe_metrics = []

    @staticmethod
    def _as_pair(value, name):
        """Returns `value` as a 2-tuple, duplicating a single (non-sequence) value.

        Args:
            value: a single layer/model, or a tuple/list of exactly two.
            name: argument name used in the error message.

        Raises:
            ValueError: if `value` is a tuple/list whose length is not 2.
        """
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError(
                    f"`{name}` must be either a single {name} or a tuple of exactly 2 {name}s."
                )
            return tuple(value)
        return (value, value)

    def compile(
        self,
        encoder_loss,
        encoder_optimizer,
        encoder_metrics=None,
        probe_optimizer=None,
        probe_loss=None,
        probe_metrics=None,
        **kwargs,
    ):
        """Configures the trainer for contrastive (and optional probe) training.

        Args:
            encoder_loss: loss forwarded to `keras.Model.compile` as `loss`;
                `train_step` calls it on the two projections.
            encoder_optimizer: optimizer forwarded as `optimizer`; it updates
                the encoder and projector weights.
            encoder_metrics: forwarded as `metrics`.
            probe_optimizer: optimizer for the probe weights. Required when a
                probe was given to the constructor.
            probe_loss: callable `(labels, logits) -> loss` for the probe.
                Required when a probe was given to the constructor.
            probe_metrics: optional list of metric objects updated with
                `(labels, class_logits)` each step.
            **kwargs: forwarded to `keras.Model.compile`.

        Raises:
            ValueError: if `loss`, `optimizer` or `metrics` is passed in
                `kwargs`, or if a probe exists but `probe_optimizer` or
                `probe_loss` is missing.
        """
        # These checks must run before `super().compile` is called. Otherwise
        # the duplicate keyword would raise an opaque TypeError first.
        if "loss" in kwargs:
            raise ValueError(
                "`loss` parameter in ContrastiveTrainer.compile is ambiguous. Please specify `encoder_loss` or `probe_loss`."
            )

        if "optimizer" in kwargs:
            raise ValueError(
                "`optimizer` parameter in ContrastiveTrainer.compile is ambiguous. Please specify `encoder_optimizer` or `probe_optimizer`."
            )

        if "metrics" in kwargs:
            raise ValueError(
                "`metrics` parameter in ContrastiveTrainer.compile is ambiguous. Please specify `encoder_metrics` or `probe_metrics`."
            )

        if self.probe is not None and not probe_optimizer:
            raise ValueError(
                "`probe_optimizer` must be specified when a probe is included."
            )

        if self.probe is not None and not probe_loss:
            raise ValueError("`probe_loss` must be specified when a probe is included.")

        super().compile(
            loss=encoder_loss,
            optimizer=encoder_optimizer,
            metrics=encoder_metrics,
            **kwargs,
        )

        if self.probe is not None:
            self.probe_loss = probe_loss
            self.probe_optimizer = probe_optimizer
            self.probe_metrics = probe_metrics or []

    @property
    def metrics(self):
        """Base-model metrics plus the contrastive loss and any probe metrics."""
        metrics = [self.loss_metric]
        if self.probe is not None:
            metrics.append(self.probe_loss_metric)
            metrics.extend(self.probe_metrics)
        return super().metrics + metrics

    def fit(
        self,
        x=None,
        y=None,
        sample_weight=None,
        batch_size=None,
        **kwargs,
    ):
        """Builds an augmented `tf.data` pipeline from the inputs and fits on it.

        The inputs are converted with `convert_inputs_to_tf_dataset`. Each
        element is then mapped through `run_augmenters` and the dataset is
        prefetched. Remaining `kwargs` are forwarded to `keras.Model.fit`.
        """
        # NOTE(review): `run_augmenters` accepts only `x` or `(x, y)` elements.
        # Confirm how `convert_inputs_to_tf_dataset` packs `sample_weight`
        # when it is not None.
        dataset = convert_inputs_to_tf_dataset(
            x=x, y=y, sample_weight=sample_weight, batch_size=batch_size
        )

        dataset = dataset.map(self.run_augmenters, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return super().fit(x=dataset, **kwargs)

    def run_augmenters(self, x, y=None):
        """Returns a dict with the raw images, optional labels and two augmented views."""
        inputs = {"images": x}
        if y is not None:
            inputs["labels"] = y

        inputs["augmented_images_0"] = self.augmenters[0](x, training=True)
        inputs["augmented_images_1"] = self.augmenters[1](x, training=True)

        return inputs

    def _contrastive_trainable_weights(self):
        """Returns encoder + projector trainable weights, each variable once.

        If one projector is shared by both sides (the single-`projector` case)
        or the projectors share layers, the same variable would otherwise
        appear twice. Its update would then be applied twice per step.
        """
        weights = []
        seen = set()
        for layer in (self.encoder, *self.projectors):
            for weight in layer.trainable_weights:
                if id(weight) not in seen:
                    seen.add(id(weight))
                    weights.append(weight)
        return weights

    def train_step(self, data):
        """Runs one contrastive update and, if configured, one probe update.

        `data` is a dict produced by `run_augmenters`.
        """
        images = data["images"]
        labels = data["labels"] if "labels" in data else None
        augmented_images_0 = data["augmented_images_0"]
        augmented_images_1 = data["augmented_images_1"]

        # Fail before touching any weights instead of after the encoder update.
        if self.probe is not None and labels is None:
            raise ValueError("Targets must be provided when a probe is specified")

        with tf.GradientTape() as tape:
            features_0 = self.encoder(augmented_images_0, training=True)
            features_1 = self.encoder(augmented_images_1, training=True)

            projections_0 = self.projectors[0](features_0, training=True)
            projections_1 = self.projectors[1](features_1, training=True)

            loss = self.compiled_loss(
                projections_0, projections_1, regularization_losses=self.encoder.losses
            )

        # Collected after the forward pass, so lazily-built layers (e.g. an
        # unbuilt Dense projector) already own their variables.
        contrastive_weights = self._contrastive_trainable_weights()
        gradients = tape.gradient(loss, contrastive_weights)
        self.optimizer.apply_gradients(zip(gradients, contrastive_weights))
        self.loss_metric.update_state(loss)

        if self.probe is not None:
            with tf.GradientTape() as tape:
                # The encoder is frozen for the probe: no gradient flows back.
                features = tf.stop_gradient(self.encoder(images, training=False))
                class_logits = self.probe(features, training=True)
                probe_loss = self.probe_loss(labels, class_logits)
            gradients = tape.gradient(probe_loss, self.probe.trainable_weights)
            self.probe_optimizer.apply_gradients(
                zip(gradients, self.probe.trainable_weights)
            )
            self.probe_loss_metric.update_state(probe_loss)
            for metric in self.probe_metrics:
                metric.update_state(labels, class_logits)

        return {metric.name: metric.result() for metric in self.metrics}

    def call(self, inputs):
        raise NotImplementedError(
            "ContrastiveTrainer.call() is not implemented - please call your model directly."
        )

    @staticmethod
    def linear_probe(classes, **kwargs):
        """Returns a `keras.Sequential` with a single `Dense(classes)` layer."""
        return keras.Sequential([keras.layers.Dense(classes)], **kwargs)