time_reduction.py 2.23 KB
Newer Older
Sehoon Kim's avatar
Sehoon Kim committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import tensorflow as tf
from ...utils import shape_util

class TimeReductionLayer(tf.keras.layers.Layer):
    def __init__(
        self,
        input_dim,
        output_dim,
        kernel_size=5,
        stride=2,
        dropout=0.0,
        name="time_reduction",
        **kwargs,
    ):
        super(TimeReductionLayer, self).__init__(name=name, **kwargs)
        self.stride = stride
        self.kernel_size = kernel_size
        dw_max = kernel_size ** -0.5
        pw_max = input_dim ** -0.5
        self.dw_conv = tf.keras.layers.DepthwiseConv2D(
            kernel_size=(kernel_size, 1), strides=self.stride,
            padding="valid", name=f"{name}_dw_conv",
            depth_multiplier=1,
            depthwise_initializer=tf.keras.initializers.RandomUniform(minval=-dw_max, maxval=dw_max),
            bias_initializer=tf.keras.initializers.RandomUniform(minval=-dw_max, maxval=dw_max),
        )
        #self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation")
        self.pw_conv = tf.keras.layers.Conv2D(
            filters=output_dim, kernel_size=1, strides=1,
            padding="valid", name=f"{name}_pw_conv_2",
            kernel_initializer=tf.keras.initializers.RandomUniform(minval=-pw_max, maxval=pw_max),
            bias_initializer=tf.keras.initializers.RandomUniform(minval=-pw_max, maxval=pw_max),
        )
    def call(self, inputs, training=False, mask=None, pad_mask=None, **kwargs):
        B, T, E = shape_util.shape_list(inputs)
        outputs = tf.reshape(inputs, [B, T, 1, E])
        _pad_mask = tf.expand_dims(tf.expand_dims(pad_mask, -1), -1)
        outputs = outputs * tf.cast(_pad_mask, "float32")
        padding = max(0, self.kernel_size - self.stride)
        outputs = tf.pad(outputs, [[0, 0], [0, padding], [0, 0], [0, 0]])
        outputs = self.dw_conv(outputs, training=training)
        outputs = self.pw_conv(outputs, training=training)
        B, T, _, E = shape_util.shape_list(outputs)
        outputs = tf.reshape(outputs, [B, T, E])

        mask = mask[:, ::self.stride, ::self.stride]
        pad_mask = pad_mask[:, ::self.stride]
        _, L = shape_util.shape_list(pad_mask)
        outputs = tf.pad(outputs, [[0, 0], [0, L - T], [0, 0]])

        return outputs, mask, pad_mask