model_s2s.py 4.01 KB
Newer Older
burchim's avatar
burchim committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright 2021, Maxime Burchi.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# PyTorch
import torch
import torch.nn as nn

# Base Model
from models.model import Model

# Encoders
from models.encoders import (
    ConformerEncoder
)

# Decoders
from models.decoders import (
    ConformerCrossDecoder,
    TransformerCrossDecoder
)

# Losses
from models.losses import (
    LossCE
)

# Ngram
import kenlm

class ModelS2S(Model):
    """Sequence-to-sequence model: audio encoder + cross-attention text decoder.

    The encoder output is attended to by the decoder, and a final linear
    layer (``self.fc``) projects decoder states to vocabulary logits.

    NOTE: the constructor currently raises unconditionally, so this class
    cannot be instantiated until the "not implemented" guard is removed.
    """

    def __init__(self, encoder_params, decoder_params, tokenizer_params, training_params, decoding_params, name):
        """Build the encoder, decoder, output projection and criterion.

        Args:
            encoder_params: dict read for ``"arch"`` and ``"dim_model"``
                (an int, or a list whose last element is used).
            decoder_params: dict read for ``"arch"``.
            tokenizer_params: dict read for ``"vocab_size"``; also forwarded
                to the base class.
            training_params: forwarded to the base class and ``self.compile``.
            decoding_params: forwarded to the base class.
            name: forwarded to the base class.

        Raises:
            Exception: always, because the model is flagged as not implemented.
        """
        super(ModelS2S, self).__init__(tokenizer_params, training_params, decoding_params, name)

        # Not Implemented
        # Deliberate guard: everything below is unreachable until it is removed.
        raise Exception("Sequence-to-sequence model not implemented")

        # Encoder
        if encoder_params["arch"] == "Conformer":
            self.encoder = ConformerEncoder(encoder_params)
        else:
            raise Exception("Unknown encoder architecture:", encoder_params["arch"])

        # Decoder
        if decoder_params["arch"] == "Conformer":
            self.decoder = ConformerCrossDecoder(decoder_params)
        elif decoder_params["arch"] == "Transformer":
            self.decoder = TransformerCrossDecoder(decoder_params)
        else:
            raise Exception("Unknown decoder architecture:", decoder_params["arch"])

        # Joint Network
        # NOTE(review): the input width is taken from the *encoder* dim_model,
        # but forward() feeds this layer the *decoder* output - confirm that
        # both dimensions are meant to match.
        self.fc = nn.Linear(encoder_params["dim_model"][-1] if isinstance(encoder_params["dim_model"], list) else encoder_params["dim_model"], tokenizer_params["vocab_size"])

        # Criterion
        self.criterion = LossCE()

        # Compile
        self.compile(training_params)

    def forward(self, batch):
        """Run encoder, decoder and output projection on one batch.

        Args:
            batch: 3-element sequence ``(x, y, _)``; the third item is ignored.
                Presumably ``x`` is the audio input and ``y`` the target token
                ids - confirm against the data loader.

        Returns:
            Tuple ``(logits, attentions)``: the output of ``self.fc`` and the
            third value returned by the encoder.
        """

        # Unpack Batch
        x, y, _ = batch

        # Audio Encoder (B, Taud) -> (B, T, Denc)
        x, _, attentions = self.encoder(x, None)

        # Add blank token: prepend one column of token id 0 on the last
        # dimension, (B, U) -> (B, U + 1)
        y = torch.nn.functional.pad(y, pad=(1, 0, 0, 0), value=0)

        # Text Decoder (B, U + 1) -> (B, U + 1, Ddec)
        y = self.decoder(x, y)

        # FC Layer (B, U + 1, Ddec) -> (B, U + 1, V)
        logits = self.fc(y)

        return logits, attentions

    def distribute_strategy(self, rank):
        """Wrap the sub-modules for distributed data-parallel training.

        Args:
            rank: forwarded to the base class ``distribute_strategy``.
        """
        super(ModelS2S, self).distribute_strategy(rank)

        # self.rank is presumably set by the base-class call above - TODO confirm.
        # Only the encoder has its batch-norm layers converted to SyncBatchNorm.
        self.encoder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.encoder)
        self.encoder = torch.nn.parallel.DistributedDataParallel(self.encoder, device_ids=[self.rank])
        self.decoder = torch.nn.parallel.DistributedDataParallel(self.decoder, device_ids=[self.rank])
        self.fc = torch.nn.parallel.DistributedDataParallel(self.fc, device_ids=[self.rank])

    def parallel_strategy(self):
        """Wrap the sub-modules in ``torch.nn.DataParallel``."""
        super(ModelS2S, self).parallel_strategy()

        self.encoder = torch.nn.DataParallel(self.encoder)
        self.decoder = torch.nn.DataParallel(self.decoder)
        self.fc = torch.nn.DataParallel(self.fc)

    def summary(self, show_dict=False):
        """Print the model name and parameter counts.

        Args:
            show_dict: when true, also print every ``state_dict`` entry with
                its shape, mean and standard deviation.
        """

        print(self.name)
        # The language model (if any) is reported separately, so exclude it here.
        print("Model Parameters :", self.num_params() - self.lm.num_params() if self.lm else self.num_params())
        print(" - Encoder Parameters :", sum(p.numel() for p in self.encoder.parameters()))
        print(" - Decoder Parameters :", sum(p.numel() for p in self.decoder.parameters()))
        # The output projection is stored as self.fc; this class never defines
        # a `joint_network` attribute.
        print(" - Joint Parameters :", sum(p.numel() for p in self.fc.parameters()))

        if self.lm:
            print("LM Parameters :", self.lm.num_params())

        if show_dict:
            for key, value in self.state_dict().items():
                print("{:<64} {:<16} mean {:<16.4f} std {:<16.4f}".format(key, str(tuple(value.size())), value.float().mean(), value.float().std()))