augmentation.py 3.22 KB
Newer Older
Sehoon Kim's avatar
Sehoon Kim committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from ..utils import shape_util


class SpecAugmentation(tf.keras.Model):
    """Randomly mask spans along the time and frequency axes of a feature batch.

    Every call multiplies the input by ``num_time_masks`` time masks (axis 1)
    followed by ``num_freq_masks`` frequency masks (axis 2). Mask positions and
    widths are drawn independently for each batch element.

    Args:
        num_freq_masks: number of frequency masks applied per call.
        freq_mask_len: exclusive upper bound on the width of one frequency
            mask, in bins.
        num_time_masks: number of time masks applied per call.
        time_mask_prop: exclusive upper bound on the width of one time mask,
            as a fraction of each example's length.
        name: model name forwarded to ``tf.keras.Model``.
        **kwargs: extra keyword arguments forwarded to ``tf.keras.Model``.
    """

    def __init__(
        self,
        num_freq_masks=2,
        freq_mask_len=27,
        num_time_masks=5,
        time_mask_prop=0.05,
        name='specaug',
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)

        self.num_freq_masks = num_freq_masks
        self.freq_mask_len = freq_mask_len
        self.num_time_masks = num_time_masks
        self.time_mask_prop = time_mask_prop

    @staticmethod
    def _keep_mask(start, width, size, dtype):
        """Build a mask that is 0 on ``[start, start + width)`` and 1 elsewhere.

        Args:
            start: int32 tensor with one span start per batch element.
            width: int32 tensor with one span width per batch element.
            size: int32 scalar, length of the axis being masked.
            dtype: dtype of the returned mask.

        Returns:
            Tensor of shape ``[batch, size]`` holding zeros inside each span
            and ones outside of it.
        """
        start = tf.reshape(start, (-1, 1))
        width = tf.reshape(width, (-1, 1))
        # [1, size] positions broadcast against the [batch, 1] span bounds,
        # so no explicit tiling is required.
        positions = tf.reshape(tf.range(size), (1, -1))
        inside = tf.logical_and(
            tf.math.greater_equal(positions, start),
            tf.math.less(positions, start + width),
        )
        return tf.cast(tf.logical_not(inside), dtype)

    def time_mask(self, inputs, inputs_len):
        """Zero out one random time span per batch element.

        The span width is ``int(length * u)`` with ``u`` uniform in
        ``[0, time_mask_prop)`` and the span start is
        ``int((length - width) * v)`` with ``v`` uniform in ``[0, 1)``.

        Args:
            inputs: tensor whose axis 0 is the batch and axis 1 is time. The
                mask is reshaped to ``(B, T, 1, 1)``, so a rank-4 input is
                expected (presumably ``[B, T, F, C]`` -- confirm with caller).
            inputs_len: integer tensor with one length per batch element.

        Returns:
            Tensor shaped like ``inputs`` with the selected frames zeroed.
        """
        input_shape = tf.shape(inputs)
        batch_size, num_frames = input_shape[0], input_shape[1]
        # Normalise to int32 so lengths given as int64 can be combined with
        # the int32 span widths below.
        lengths = tf.cast(inputs_len, tf.int32)

        width = tf.random.uniform(
            shape=tf.shape(lengths), minval=0, maxval=self.time_mask_prop)
        width = tf.cast(tf.cast(lengths, tf.float32) * width, tf.int32)
        start = tf.random.uniform(shape=tf.shape(lengths), minval=0, maxval=1)
        start = tf.cast(tf.cast(lengths - width, tf.float32) * start, tf.int32)

        # Build the mask in the input dtype so non-float32 inputs (e.g. under
        # mixed precision) can be multiplied without a dtype mismatch.
        mask = self._keep_mask(start, width, num_frames, inputs.dtype)
        return inputs * tf.reshape(mask, (batch_size, num_frames, 1, 1))

    def frequency_mask(self, inputs, inputs_len):
        """Zero out one random frequency span per batch element.

        The span width is an integer uniform in ``[0, freq_mask_len)`` and the
        span start is ``int((F - width) * v)`` with ``v`` uniform in ``[0, 1)``,
        where ``F`` is the size of axis 2.

        Args:
            inputs: tensor whose axis 0 is the batch and axis 2 is frequency.
                The mask is reshaped to ``(B, 1, F, 1)``, so a rank-4 input is
                expected (presumably ``[B, T, F, C]`` -- confirm with caller).
            inputs_len: tensor with one entry per batch element; only its
                shape is used, to size the random draws.

        Returns:
            Tensor shaped like ``inputs`` with the selected bins zeroed.
        """
        input_shape = tf.shape(inputs)
        batch_size, num_bins = input_shape[0], input_shape[2]
        per_example = tf.shape(inputs_len)

        width = tf.random.uniform(
            shape=per_example, minval=0, maxval=self.freq_mask_len,
            dtype=tf.int32)
        start = tf.random.uniform(shape=per_example, minval=0, maxval=1)
        start = tf.cast(
            tf.cast(num_bins - width, tf.float32) * start, tf.int32)

        # Mask is created in the input dtype; see time_mask for the rationale.
        mask = self._keep_mask(start, width, num_bins, inputs.dtype)
        return inputs * tf.reshape(mask, (batch_size, 1, num_bins, 1))

    @tf.function
    def call(self, inputs, inputs_len):
        """Apply all time masks, then all frequency masks, to ``inputs``.

        There is no ``training`` switch: masking is applied on every call.

        Args:
            inputs: feature tensor, see ``time_mask`` / ``frequency_mask``.
            inputs_len: integer tensor with one length per batch element.

        Returns:
            The masked tensor, same shape as ``inputs``.
        """
        masked_inputs = inputs
        # Plain Python loops: the mask counts are Python ints, so the loops
        # are unrolled when the function is traced.
        for _ in range(self.num_time_masks):
            masked_inputs = self.time_mask(masked_inputs, inputs_len)
        for _ in range(self.num_freq_masks):
            masked_inputs = self.frequency_mask(masked_inputs, inputs_len)
        return masked_inputs