ctc.py 14.3 KB
Newer Older
SWHL's avatar
SWHL committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Import the CTC decoder bindings used by CTCDecoder below. The import is
# treated as optional: a failure is reported on stdout instead of aborting the
# import of this module.
try:
    from ..decoders.ctcdecoder import CTCBeamSearchDecoder  # noqa: F401
    from ..decoders.ctcdecoder import Scorer  # noqa: F401
    from ..decoders.ctcdecoder import \
        ctc_beam_search_decoding_batch  # noqa: F401
    from ..decoders.ctcdecoder import ctc_greedy_decoding  # noqa: F401
except ImportError:
    # NOTE(review): this fallback repeats exactly the same imports as the
    # first attempt, from the same module path. Presumably a different source
    # (or an install step) was intended here -- TODO confirm.
    try:
        from ..decoders.ctcdecoder import CTCBeamSearchDecoder  # noqa: F401
        from ..decoders.ctcdecoder import Scorer  # noqa: F401
        from ..decoders.ctcdecoder import \
            ctc_beam_search_decoding_batch  # noqa: F401
        from ..decoders.ctcdecoder import ctc_greedy_decoding  # noqa: F401
    except Exception as e:
        # Best effort: only a message is printed (the caught exception `e` is
        # not used). The four names above stay undefined in that case, so any
        # later use of them raises NameError.
        print("paddlespeech_ctcdecoders not installed!")


class CTCDecoder(object):
    def __init__(self):
        # CTCDecoder LM Score handle
        self._ext_scorer = None
        self.beam_search_decoder = None
        self.blank_id = 0

    def _decode_batch_greedy_offline(self, probs_split, vocab_list):
        """This function will be deprecated in future.
        Decode by best path for a batch of probs matrix input.
        :param probs_split: List of 2-D probability matrix, and each consists
                            of prob vectors for one speech utterancce.
        :param probs_split: List of matrix
        :param vocab_list: List of tokens in the vocabulary, for decoding.
        :type vocab_list: list
        :return: List of transcription texts.
        :rtype: List of str
        """
        results = []
        for i, probs in enumerate(probs_split):
            output_transcription = ctc_greedy_decoding(
                probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id)
            results.append(output_transcription)
        return results

    def _init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
                         vocab_list):
        """Initialize the external scorer.
        :param beam_alpha: Parameter associated with language model.
        :type beam_alpha: float
        :param beam_beta: Parameter associated with word count.
        :type beam_beta: float
        :param language_model_path: Filepath for language model. If it is
                                    empty, the external scorer will be set to
                                    None, and the decoding method will be pure
                                    beam search without scorer.
        :type language_model_path: str|None
        :param vocab_list: List of tokens in the vocabulary, for decoding.
        :type vocab_list: list
        """
        # init once
        if self._ext_scorer is not None:
            return

        if language_model_path != '':
            self._ext_scorer = Scorer(beam_alpha, beam_beta,
                                      language_model_path, vocab_list)
        else:
            self._ext_scorer = None

    def _decode_batch_beam_search_offline(
            self, probs_split, beam_alpha, beam_beta, beam_size, cutoff_prob,
            cutoff_top_n, vocab_list, num_processes):
        """
        This function will be deprecated in future.
        Decode by beam search for a batch of probs matrix input.
        :param probs_split: List of 2-D probability matrix, and each consists
                            of prob vectors for one speech utterancce.
        :param probs_split: List of matrix
        :param beam_alpha: Parameter associated with language model.
        :type beam_alpha: float
        :param beam_beta: Parameter associated with word count.
        :type beam_beta: float
        :param beam_size: Width for Beam search.
        :type beam_size: int
        :param cutoff_prob: Cutoff probability in pruning,
                            default 1.0, no pruning.
        :type cutoff_prob: float
        :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                        characters with highest probs in vocabulary will be
                        used in beam search, default 40.
        :type cutoff_top_n: int
        :param vocab_list: List of tokens in the vocabulary, for decoding.
        :type vocab_list: list
        :param num_processes: Number of processes (CPU) for decoder.
        :type num_processes: int
        :return: List of transcription texts.
        :rtype: List of str
        """
        if self._ext_scorer is not None:
            self._ext_scorer.reset_params(beam_alpha, beam_beta)

        # beam search decode
        num_processes = min(num_processes, len(probs_split))
        beam_search_results = ctc_beam_search_decoding_batch(
            probs_split=probs_split,
            vocabulary=vocab_list,
            beam_size=beam_size,
            num_processes=num_processes,
            ext_scoring_func=self._ext_scorer,
            cutoff_prob=cutoff_prob,
            cutoff_top_n=cutoff_top_n,
            blank_id=self.blank_id)

        results = [result[0][1] for result in beam_search_results]
        return results

    def init_decoder(self, batch_size, vocab_list, decoding_method,
                     lang_model_path, beam_alpha, beam_beta, beam_size,
                     cutoff_prob, cutoff_top_n, num_processes):
        """
        init ctc decoders
        Args:
            batch_size(int): Batch size for input data
            vocab_list (list): List of tokens in the vocabulary, for decoding
            decoding_method (str): ctc_beam_search
            lang_model_path (str): language model path
            beam_alpha (float): beam_alpha
            beam_beta (float): beam_beta
            beam_size (int): beam_size
            cutoff_prob (float): cutoff probability in beam search
            cutoff_top_n (int): cutoff_top_n
            num_processes (int): num_processes

        Raises:
            ValueError: when decoding_method not support.

        Returns:
            CTCBeamSearchDecoder
        """
        self.batch_size = batch_size
        self.vocab_list = vocab_list
        self.decoding_method = decoding_method
        self.beam_size = beam_size
        self.cutoff_prob = cutoff_prob
        self.cutoff_top_n = cutoff_top_n
        self.num_processes = num_processes
        if decoding_method == "ctc_beam_search":
            self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
                                  vocab_list)
            if self.beam_search_decoder is None:
                self.beam_search_decoder = self.get_decoder(
                    vocab_list, batch_size, beam_alpha, beam_beta, beam_size,
                    num_processes, cutoff_prob, cutoff_top_n)
            return self.beam_search_decoder
        elif decoding_method == "ctc_greedy":
            self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
                                  vocab_list)
        else:
            raise ValueError(f"Not support: {decoding_method}")

    def decode_probs_offline(self, probs, logits_lens, vocab_list,
                             decoding_method, lang_model_path, beam_alpha,
                             beam_beta, beam_size, cutoff_prob, cutoff_top_n,
                             num_processes):
        """
        This function will be deprecated in future.
        ctc decoding with probs.
        Args:
            probs (Tensor): activation after softmax
            logits_lens (Tensor): audio output lens
            vocab_list (list): List of tokens in the vocabulary, for decoding
            decoding_method (str): ctc_beam_search
            lang_model_path (str): language model path
            beam_alpha (float): beam_alpha
            beam_beta (float): beam_beta
            beam_size (int): beam_size
            cutoff_prob (float): cutoff probability in beam search
            cutoff_top_n (int): cutoff_top_n
            num_processes (int): num_processes

        Raises:
            ValueError: when decoding_method not support.

        Returns:
            List[str]: transcripts.
        """
        logger.warn(
            "This function will be deprecated in future: decode_probs_offline")
        probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)]
        if decoding_method == "ctc_greedy":
            result_transcripts = self._decode_batch_greedy_offline(
                probs_split=probs_split, vocab_list=vocab_list)
        elif decoding_method == "ctc_beam_search":
            result_transcripts = self._decode_batch_beam_search_offline(
                probs_split=probs_split,
                beam_alpha=beam_alpha,
                beam_beta=beam_beta,
                beam_size=beam_size,
                cutoff_prob=cutoff_prob,
                cutoff_top_n=cutoff_top_n,
                vocab_list=vocab_list,
                num_processes=num_processes)
        else:
            raise ValueError(f"Not support: {decoding_method}")
        return result_transcripts

    def get_decoder(self, vocab_list, batch_size, beam_alpha, beam_beta,
                    beam_size, num_processes, cutoff_prob, cutoff_top_n):
        """
        init get ctc decoder
        Args:
            vocab_list (list): List of tokens in the vocabulary, for decoding.
            batch_size(int): Batch size for input data
            beam_alpha (float): beam_alpha
            beam_beta (float): beam_beta
            beam_size (int): beam_size
            num_processes (int): num_processes
            cutoff_prob (float): cutoff probability in beam search
            cutoff_top_n (int): cutoff_top_n

        Raises:
            ValueError: when decoding_method not support.

        Returns:
            CTCBeamSearchDecoder
        """
        num_processes = min(num_processes, batch_size)
        if self._ext_scorer is not None:
            self._ext_scorer.reset_params(beam_alpha, beam_beta)
        if self.decoding_method == "ctc_beam_search":
            beam_search_decoder = CTCBeamSearchDecoder(
                vocab_list, batch_size, beam_size, num_processes, cutoff_prob,
                cutoff_top_n, self._ext_scorer, self.blank_id)
        else:
            raise ValueError(f"Not support: {self.decoding_method}")
        return beam_search_decoder

    def next(self, probs, logits_lens):
        """
        Input probs into ctc decoder
        Args:
            probs (list(list(float))): probs for a batch of data
            logits_lens (list(int)): logits lens for a batch of data
        Raises:
            Exception: when the ctc decoder is not initialized
            ValueError: when decoding_method not support.
        """

        if self.beam_search_decoder is None:
            raise Exception(
                "You need to initialize the beam_search_decoder firstly")
        beam_search_decoder = self.beam_search_decoder

        has_value = (logits_lens > 0).tolist()
        has_value = [
            "true" if has_value[i] is True else "false"
            for i in range(len(has_value))
        ]
        probs_split = [
            probs[i, :l, :].tolist() if has_value[i] else probs[i].tolist()
            for i, l in enumerate(logits_lens)
        ]
        if self.decoding_method == "ctc_beam_search":
            beam_search_decoder.next(probs_split, has_value)
        else:
            raise ValueError(f"Not support: {self.decoding_method}")

        return

    def decode(self):
        """
        Get the decoding result
        Raises:
            Exception: when the ctc decoder is not initialized
            ValueError: when decoding_method not support.
        Returns:
            results_best (list(str)): The best result for a batch of data
            results_beam (list(list(str))): The beam search result for a batch of data
        """
        if self.beam_search_decoder is None:
            raise Exception(
                "You need to initialize the beam_search_decoder firstly")

        beam_search_decoder = self.beam_search_decoder
        if self.decoding_method == "ctc_beam_search":
            batch_beam_results = beam_search_decoder.decode()
            batch_beam_results = [[(res[0], res[1]) for res in beam_results]
                                  for beam_results in batch_beam_results]
            results_best = [result[0][1] for result in batch_beam_results]
            results_beam = [[trans[1] for trans in result]
                            for result in batch_beam_results]

        else:
            raise ValueError(f"Not support: {self.decoding_method}")

        return results_best, results_beam

    def reset_decoder(self,
                      batch_size=-1,
                      beam_size=-1,
                      num_processes=-1,
                      cutoff_prob=-1.0,
                      cutoff_top_n=-1):
        if batch_size > 0:
            self.batch_size = batch_size
        if beam_size > 0:
            self.beam_size = beam_size
        if num_processes > 0:
            self.num_processes = num_processes
        if cutoff_prob > 0:
            self.cutoff_prob = cutoff_prob
        if cutoff_top_n > 0:
            self.cutoff_top_n = cutoff_top_n
        """
        Reset the decoder state
        Args:
            batch_size(int): Batch size for input data
            beam_size (int): beam_size
            num_processes (int): num_processes
            cutoff_prob (float): cutoff probability in beam search
            cutoff_top_n (int): cutoff_top_n
        Raises:
            Exception: when the ctc decoder is not initialized
        """
        if self.beam_search_decoder is None:
            raise Exception(
                "You need to initialize the beam_search_decoder firstly")
        self.beam_search_decoder.reset_state(
            self.batch_size, self.beam_size, self.num_processes,
            self.cutoff_prob, self.cutoff_top_n)

    def del_decoder(self):
        """
        Delete the decoder
        """
        if self.beam_search_decoder is not None:
            del self.beam_search_decoder
            self.beam_search_decoder = None