conformer.py 8.82 KB
Newer Older
Sehoon Kim's avatar
Sehoon Kim committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import tensorflow as tf
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.framework import ops
from tensorflow.python.eager import def_function

from .ctc import CtcModel
from .conformer_encoder import ConformerEncoder
from ..augmentations.augmentation import SpecAugmentation
from ..utils import math_util
from ..utils.training_utils import (
    _minimum_control_deps,
    reduce_per_replica,
    write_scalar_summaries,
)

class ConformerCtc(CtcModel):
    """Conformer encoder with a 1x1 ``Conv1D`` CTC output layer.

    The encoder, the decoder (a kernel-size-1 ``Conv1D`` projecting onto the
    vocabulary) and an optional SpecAugment layer are built here and handed to
    ``CtcModel``. ``make_train_function`` / ``make_test_function`` are
    overridden so that the per-replica outputs of every step are combined with
    the project's ``reduce_per_replica`` helper (the stated intent is to gather
    the outputs from multiple TPU cores).
    """

    def __init__(
        self,
        vocabulary_size: int,
        encoder_subsampling: dict,
        encoder_dmodel: int = 144,
        encoder_num_blocks: int = 16,
        encoder_head_size: int = 36,
        encoder_num_heads: int = 4,
        encoder_mha_type: str = "relmha",
        encoder_kernel_size: int = 32,
        encoder_fc_factor: float = 0.5,
        encoder_dropout: float = 0,
        encoder_time_reduce_idx: list = None,
        encoder_time_recover_idx: list = None,
        encoder_conv_use_glu: bool = False,
        encoder_ds_subsample: bool = False,
        encoder_no_post_ln: bool = False,
        encoder_adaptive_scale: bool = False,
        encoder_fixed_arch: list = None,
        augmentation_config=None,
        name: str = "conformer",
        **kwargs,
    ) -> None:
        """Build the encoder, decoder and optional augmentation layer.

        Args:
            vocabulary_size: number of filters of the output ``Conv1D``; also
                forwarded to ``CtcModel``.
            encoder_subsampling: forwarded as ``subsampling`` to
                ``ConformerEncoder``.
            encoder_dmodel, encoder_num_blocks, encoder_head_size,
            encoder_num_heads, encoder_mha_type, encoder_kernel_size,
            encoder_fc_factor, encoder_dropout, encoder_time_reduce_idx,
            encoder_time_recover_idx, encoder_conv_use_glu,
            encoder_ds_subsample, encoder_no_post_ln, encoder_adaptive_scale:
                forwarded to the same-named ``ConformerEncoder`` argument
                (without the ``encoder_`` prefix).
            encoder_fixed_arch: ``None``, a single per-block spec (a list whose
                first element is not a list) that is replicated
                ``encoder_num_blocks`` times, or a list of per-block specs
                (first element is a list) that is used as-is.
            augmentation_config: ``None`` to disable SpecAugment; otherwise a
                dict with ``freq_masking`` (``num_masks``, ``mask_factor``) and
                ``time_masking`` (``num_masks``, ``p_upperbound``) sections.
            name: model name; also used as prefix for the sub-layer names.
            **kwargs: forwarded to ``CtcModel.__init__``.

        Raises:
            ValueError: if ``encoder_dmodel`` is not
                ``encoder_num_heads * encoder_head_size``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if encoder_dmodel != encoder_num_heads * encoder_head_size:
            raise ValueError(
                f"encoder_dmodel ({encoder_dmodel}) must equal "
                f"encoder_num_heads * encoder_head_size "
                f"({encoder_num_heads} * {encoder_head_size})"
            )
        # A flat spec describes one block, so replicate it for every block.
        # The default ``None`` used to be subscripted here and crashed; it is
        # now passed through unchanged.
        # NOTE(review): assumes ConformerEncoder accepts fixed_arch=None — confirm.
        if encoder_fixed_arch is not None and not isinstance(encoder_fixed_arch[0], list):
            encoder_fixed_arch = [encoder_fixed_arch] * encoder_num_blocks
        super().__init__(
            encoder=ConformerEncoder(
                subsampling=encoder_subsampling,
                dmodel=encoder_dmodel,
                num_blocks=encoder_num_blocks,
                head_size=encoder_head_size,
                num_heads=encoder_num_heads,
                mha_type=encoder_mha_type,
                kernel_size=encoder_kernel_size,
                fc_factor=encoder_fc_factor,
                dropout=encoder_dropout,
                time_reduce_idx=encoder_time_reduce_idx,
                time_recover_idx=encoder_time_recover_idx,
                conv_use_glu=encoder_conv_use_glu,
                ds_subsample=encoder_ds_subsample,
                no_post_ln=encoder_no_post_ln,
                adaptive_scale=encoder_adaptive_scale,
                fixed_arch=encoder_fixed_arch,
                name=f"{name}_encoder",
            ),
            decoder=tf.keras.layers.Conv1D(
                filters=vocabulary_size, kernel_size=1,
                strides=1, padding="same",
                name=f"{name}_logits"
            ),
            augmentation=SpecAugmentation(
                num_freq_masks=augmentation_config['freq_masking']['num_masks'],
                freq_mask_len=augmentation_config['freq_masking']['mask_factor'],
                num_time_masks=augmentation_config['time_masking']['num_masks'],
                time_mask_prop=augmentation_config['time_masking']['p_upperbound'],
                name=f"{name}_specaug"
            ) if augmentation_config is not None else None,
            vocabulary_size=vocabulary_size,
            name=name,
            **kwargs
        )
        self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor
        self.dmodel = encoder_dmodel

    # The following functions override the original function
    # in order to gather the outputs from multiple TPU cores

    def make_train_function(self):
        """Build (once) and cache the function that runs training steps.

        Each step runs ``train_step`` under ``distribute_strategy.run``. It
        bumps ``_train_counter`` only after the step's outputs exist, combines
        the per-replica outputs with ``reduce_per_replica`` and writes scalar
        summaries. The result is wrapped in a ``tf.function`` unless
        ``run_eagerly`` is set, and is scheduled through the cluster
        coordinator when one is present.

        Returns:
            The cached ``self.train_function``.
        """
        if self.train_function is not None:
            return self.train_function

        def step_function(model, iterator):
            """Runs a single training step."""

            def run_step(data):
                outputs = model.train_step(data)
                # Ensure counter is updated only if `train_step` succeeds.
                with ops.control_dependencies(_minimum_control_deps(outputs)):
                    model._train_counter.assign_add(1)  # pylint: disable=protected-access
                return outputs

            data = next(iterator)
            outputs = model.distribute_strategy.run(run_step, args=(data,))
            outputs = reduce_per_replica(outputs, self.distribute_strategy)
            write_scalar_summaries(outputs, step=model._train_counter)  # pylint: disable=protected-access
            return outputs

        if self._steps_per_execution.numpy().item() == 1:

            def train_function(iterator):
                """Runs a training execution with one step."""
                return step_function(self, iterator)
        else:

            def train_function(iterator):
                """Runs a training execution with multiple steps."""
                # ``tf.range`` (== math_ops.range, which was never imported
                # here) so AutoGraph lowers this into an in-graph loop.
                for _ in tf.range(self._steps_per_execution):
                    outputs = step_function(self, iterator)
                return outputs

        if not self.run_eagerly:
            train_function = def_function.function(
                train_function, experimental_relax_shapes=True)

        self.train_function = train_function

        if self._cluster_coordinator:
            self.train_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
                train_function, args=(iterator,))

        return self.train_function

    def make_test_function(self):
        """Build (once) and cache the function that runs evaluation steps.

        Mirrors ``make_train_function`` but calls ``test_step``, bumps
        ``_test_counter`` and writes no summaries.

        Returns:
            The cached ``self.test_function``.
        """
        if self.test_function is not None:
            return self.test_function

        def step_function(model, iterator):
            """Runs a single evaluation step."""

            def run_step(data):
                outputs = model.test_step(data)
                # Ensure counter is updated only if `test_step` succeeds.
                with ops.control_dependencies(_minimum_control_deps(outputs)):
                    model._test_counter.assign_add(1)  # pylint: disable=protected-access
                return outputs

            data = next(iterator)
            outputs = model.distribute_strategy.run(run_step, args=(data,))
            outputs = reduce_per_replica(outputs, self.distribute_strategy)
            return outputs

        if self._steps_per_execution.numpy().item() == 1:

            def test_function(iterator):
                """Runs an evaluation execution with one step."""
                return step_function(self, iterator)
        else:

            def test_function(iterator):
                """Runs an evaluation execution with multiple steps."""
                # ``tf.range`` (== math_ops.range, which was never imported
                # here) so AutoGraph lowers this into an in-graph loop.
                for _ in tf.range(self._steps_per_execution):
                    outputs = step_function(self, iterator)
                return outputs

        if not self.run_eagerly:
            test_function = def_function.function(test_function, experimental_relax_shapes=True)

        self.test_function = test_function

        if self._cluster_coordinator:
            self.test_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
                test_function, args=(iterator,))

        return self.test_function


class ConformerCtcAccumulate(ConformerCtc):
    """``ConformerCtc`` that accumulates gradients over several steps.

    Every ``train_step`` adds ``gradient / n_gradients`` to a float32
    accumulator per trainable variable. On every ``n_gradients``-th step the
    accumulated values are applied with the optimizer and the accumulators and
    step counter are reset to zero.

    NOTE(review): ``n_acum_step`` and the accumulators are plain
    ``tf.Variable``s updated from ``train_step``. Under a multi-replica
    ``tf.distribute`` strategy they may need an explicit ``aggregation`` /
    ``synchronization`` setting — verify before using with more than one
    replica.
    """

    def __init__(self, n_gradients: int = 1, **kwargs) -> None:
        """Create the accumulation constants and step counter.

        Args:
            n_gradients: number of ``train_step`` calls whose gradients are
                averaged into a single optimizer update. Must be >= 1.
            **kwargs: forwarded to ``ConformerCtc.__init__``.

        Raises:
            ValueError: if ``n_gradients`` is smaller than 1. Zero would
                divide every gradient by zero, and the update condition
                would never be met.
        """
        if n_gradients < 1:
            raise ValueError(f"n_gradients must be >= 1, got {n_gradients}")
        super().__init__(**kwargs)
        # ``time_reduction_factor`` is already set by ConformerCtc.__init__.

        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32, name="conformer/num_accumulated_gradients")
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="conformer/accumulate_step")

    def make(self, input_shape, batch_size=None):
        """Build the model, then allocate one zero-filled float32 accumulator
        per trainable variable, in ``trainable_variables`` order."""
        super().make(input_shape, batch_size)
        self.gradient_accumulation = [
            tf.Variable(
                tf.zeros_like(var, dtype=tf.float32),
                trainable=False,
                name=f"{var.name}/cached_accumulated_gradient",
            )
            for var in self.trainable_variables
        ]

    def train_step(self, batch):
        """Accumulate one batch's gradients and apply them every n-th step.

        Args:
            batch ([tf.Tensor]): a ``(inputs, y_true)`` pair of training data.

        Returns:
            Dict[tf.Tensor]: the current result of every training metric,
            keyed by metric name.
        """
        self.n_acum_step.assign_add(1)

        inputs, y_true = batch
        loss, y_pred, gradients = self.gradient_step(inputs, y_true)

        # Scale each step's gradient by 1/n_gradients so the accumulated sum
        # is the mean over the accumulation window.
        scale = tf.cast(self.n_gradients, tf.float32)
        for idx, accumulator in enumerate(self.gradient_accumulation):
            accumulator.assign_add(gradients[idx] / scale)

        tf.cond(tf.equal(self.n_acum_step, self.n_gradients), self.apply_accu_gradients, lambda: None)

        self._metrics["loss"].update_state(loss)
        if 'WER' in self._metrics:
            self._metrics['WER'].update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def apply_accu_gradients(self):
        """Apply the accumulated gradients, then reset the step counter and
        zero every accumulator. Returns ``None``, matching the ``tf.cond``
        false branch in ``train_step``."""
        trainable = self.trainable_variables
        self.optimizer.apply_gradients(zip(self.gradient_accumulation, trainable))

        # Reset for the next accumulation window.
        self.n_acum_step.assign(0)
        for idx, accumulator in enumerate(self.gradient_accumulation):
            accumulator.assign(tf.zeros_like(trainable[idx], dtype=tf.float32))