"""
FSMN Pytorch Version
"""
import torch.nn as nn
import torch.nn.functional as F


class FeedForwardNet(nn.Module):
    """A position-wise two-layer feed-forward module built from Conv1d."""

    def __init__(self, d_in, d_hid, d_out, kernel_size=(1, 1), dropout=0.1):
        super().__init__()

        # Use Conv1D
        # position-wise
        self.w_1 = nn.Conv1d(
            d_in,
            d_hid,
            kernel_size=kernel_size[0],
            padding=(kernel_size[0] - 1) // 2,
        )
        # position-wise
        self.w_2 = nn.Conv1d(
            d_hid,
            d_out,
            kernel_size=kernel_size[1],
            padding=(kernel_size[1] - 1) // 2,
            bias=False,
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # (batch, time, d_in) -> (batch, d_in, time) for Conv1d
        output = x.transpose(1, 2)
        output = F.relu(self.w_1(output))
        output = self.dropout(output)
        output = self.w_2(output)
        # back to (batch, time, d_out)
        output = output.transpose(1, 2)

        return output


class MemoryBlockV2(nn.Module):
    """FSMN memory block: a depthwise Conv1d over time with a residual path."""

    def __init__(self, d, filter_size, shift, dropout=0.0):
        super(MemoryBlockV2, self).__init__()

        # Asymmetric "same" padding over time: ceil((filter_size - 1) / 2)
        # frames on the left (past), floor((filter_size - 1) / 2) on the
        # right (future).  Integer arithmetic avoids Python 3's banker's
        # rounding, which mis-splits even filter sizes such as 2 and 6.
        left_padding = (filter_size - 1) - (filter_size - 1) // 2
        right_padding = (filter_size - 1) // 2
        # A positive shift biases the window toward the past; right_padding
        # may become negative, in which case F.pad crops future frames.
        if shift > 0:
            left_padding += shift
            right_padding -= shift

        self.lp, self.rp = left_padding, right_padding

        # Depthwise (groups=d) 1-D convolution without bias: the memory filter.
        self.conv_dw = nn.Conv1d(d, d, filter_size, 1, 0, groups=d, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, mask=None):
        # input: (batch, time, d); mask: (batch, time), True at padded frames
        if mask is not None:
            input = input.masked_fill(mask.unsqueeze(-1), 0)

        # Pad the time axis, then apply the depthwise memory convolution.
        x = F.pad(input, (0, 0, self.lp, self.rp, 0, 0), mode="constant", value=0.0)
        output = (
            self.conv_dw(x.contiguous().transpose(1, 2)).contiguous().transpose(1, 2)
        )
        output += input  # residual connection around the memory block
        output = self.dropout(output)

        if mask is not None:
            output = output.masked_fill(mask.unsqueeze(-1), 0)

        return output
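

# Worked example of the padding arithmetic (illustrative, not from the
# original file): with filter_size=5 and shift=2, left_padding = 2 + 2 = 4
# and right_padding = 2 - 2 = 0, so output frame t depends on input frames
# t-4..t, i.e. a purely causal five-frame memory.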


class FsmnEncoderV2(nn.Module):
    """Stack of FSMN layers: each layer is a FeedForwardNet followed by a
    MemoryBlockV2, with residual connections between equal-width layers."""

    def __init__(
        self,
        filter_size,
        fsmn_num_layers,
        input_dim,
        num_memory_units,
        ffn_inner_dim,
        dropout=0.0,
        shift=0,
    ):
        super(FsmnEncoderV2, self).__init__()

        self.filter_size = filter_size
        self.fsmn_num_layers = fsmn_num_layers
        self.num_memory_units = num_memory_units
        self.ffn_inner_dim = ffn_inner_dim
        self.dropout = dropout
        self.shift = shift
        # shift may be a single int (applied to every layer) or a per-layer list
        if not isinstance(shift, list):
            self.shift = [shift for _ in range(self.fsmn_num_layers)]

        # The first FFN projects input_dim -> num_memory_units; later layers
        # keep the width constant so the inter-layer residuals line up.
        self.ffn_lst = nn.ModuleList()
        self.ffn_lst.append(
            FeedForwardNet(input_dim, ffn_inner_dim, num_memory_units, dropout=dropout)
        )
        for _ in range(1, fsmn_num_layers):
            self.ffn_lst.append(
                FeedForwardNet(
                    num_memory_units, ffn_inner_dim, num_memory_units, dropout=dropout
                )
            )

        self.memory_block_lst = nn.ModuleList()
        for i in range(fsmn_num_layers):
            self.memory_block_lst.append(
                MemoryBlockV2(num_memory_units, filter_size, self.shift[i], dropout)
            )

    def forward(self, input, mask=None):
        x = F.dropout(input, self.dropout, self.training)
        for ffn, memory_block in zip(self.ffn_lst, self.memory_block_lst):
            context = ffn(x)
            memory = memory_block(context, mask)
            memory = F.dropout(memory, self.dropout, self.training)
            # Inter-layer residual; skipped for the first layer whenever
            # input_dim != num_memory_units.
            if memory.size(-1) == x.size(-1):
                memory += x
            x = memory

        return x
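

if __name__ == "__main__":
    # Minimal smoke test; the shapes and hyper-parameters below are
    # illustrative assumptions, not values from the original file.
    import torch

    encoder = FsmnEncoderV2(
        filter_size=11,
        fsmn_num_layers=4,
        input_dim=80,
        num_memory_units=256,
        ffn_inner_dim=512,
        dropout=0.1,
        shift=1,
    )
    feats = torch.randn(2, 100, 80)  # (batch, time, feature)
    # True marks padded frames; the second utterance is 20 frames shorter.
    mask = torch.zeros(2, 100, dtype=torch.bool)
    mask[1, 80:] = True
    out = encoder(feats, mask)
    print(out.shape)  # torch.Size([2, 100, 256])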