import torch
from torch import Tensor
from typing import Tuple, Optional
from .mutual_information import (mutual_information_recursion,
                                 joint_mutual_information_recursion)



def get_rnnt_logprobs(lm: Tensor,
                      am: Tensor,
                      symbols: Tensor,
                      termination_symbol: int) -> Tuple[Tensor, Tensor]:
    """
    Reduces the RNN-T problem (the simple case, where the joiner network is
    just addition) to a compact, standard form that can then be given
    (with boundaries) to mutual_information_recursion().  This function is
    called from rnnt_loss_simple(), but may be useful for other purposes.

    Args:
         lm:  Language model part of un-normalized logprobs of symbols, to be added to
              acoustic model part before normalizing.  Of shape:
                 [B][S+1][C]
              where B is the batch size, S is the maximum sequence length of
              the symbol sequence, possibly including the EOS symbol; and
              C is the size of the symbol vocabulary, including the
              termination/next-frame symbol.
              Conceptually, lm[b][s] is a vector of length [C] representing the
              "language model" part of the un-normalized logprobs of symbols,
              given all symbols *earlier than* s in the sequence.  The reason
              we still need this for position S is that we may still be emitting
              the termination/next-frame symbol at this point.
         am:  Acoustic-model part of un-normalized logprobs of symbols, to be added
              to language-model part before normalizing.  Of shape:
                 [B][T][C]
              where B is the batch size, T is the maximum sequence length of
              the acoustic sequences (in frames); and C is the size of the symbol
              vocabulary, including the termination/next-frame symbol.  It reflects
              the "acoustic" part of the probability of any given symbol appearing
              next on this frame.
          symbols: A LongTensor of shape [B][S], containing the symbols at each position
              of the sequence, possibly including EOS
          termination_symbol: The identity of the termination symbol, must be
               in {0..C-1}
    Returns: (px, py) (the names are quite arbitrary).
              px: logprobs, of shape [B][S][T+1]
              py: logprobs, of shape [B][S+1][T]
          in the recursion:
             p[b,0,0] = 0.0
             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
          .. where p[b][s][t] is the "joint score" of the pair of subsequences
          of length s and t respectively.  px[b][s][t] represents the
          log-probability of extending the subsequences of length (s,t) by one
          in the s direction, given the particular symbol, and py[b][s][t]
          represents the log-probability of extending the subsequences of
          length (s,t) by one in the t direction, i.e. of emitting the
          termination/next-frame symbol.

          px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
          we cannot emit any symbols.  This is simply a way of incorporating
          the probability of the termination symbol on the last frame.
    """
    assert lm.ndim == 3 and am.ndim == 3 and lm.shape[0] == am.shape[0] \
        and lm.shape[2] == am.shape[2]
    (B, T, C) = am.shape
    S = lm.shape[1] - 1
    assert symbols.shape == (B, S)

    # We subtract am_max and lm_max to keep the values in a good range for
    # exp(), avoiding underflow or overflow.
    am_max, _ = torch.max(am, dim=2, keepdim=True)  # am_max: [B][T][1]
    lm_max, _ = torch.max(lm, dim=2, keepdim=True)  # lm_max: [B][S+1][1]
    am_probs = (am - am_max).exp()
    lm_probs = (lm - lm_max).exp()
    # normalizers: [B][S+1][T]
    normalizers = (torch.matmul(lm_probs, am_probs.transpose(1, 2)) + 1.0e-20).log()

    # add lm_max and am_max to normalizers, to make it as if we had not
    # subtracted am_max and lm_max above.
    normalizers = normalizers + lm_max + am_max.transpose(1, 2)  # [B][S+1][T]
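    # At this point, normalizers[b][s][t] == log(sum_c exp(lm[b][s][c] + am[b][t][c])),
    # i.e. a logsumexp over the symbol dimension, computed via exp/matmul/log
    # for efficiency.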

    # px is the log-probs of the actual symbols (not yet normalized)..
    px_am = torch.gather(am.unsqueeze(1).expand(B, S, T, C), dim=3,
                         index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1)).squeeze(-1) # [B][S][T]
    px_am = torch.cat((px_am,
                       torch.full((B, S, 1), float('-inf'),
                                  device=px_am.device, dtype=px_am.dtype)),
                      dim=2)  # now: [B][S][T+1], index [:,:,T] has -inf..

    px_lm = torch.gather(lm[:,:S], dim=2, index=symbols.unsqueeze(-1)) # [B][S][1]

    px = px_am + px_lm  # [B][S][T+1], last slice indexed [:,:,T] is -inf
    px[:,:,:T] -= normalizers[:,:S,:] # px: [B][S][T+1]

    # py is the probs of termination symbols, of shape [B][S+1][T]
    py_am = am[:,:,termination_symbol].unsqueeze(1) # [B][1][T]
    py_lm = lm[:,:,termination_symbol].unsqueeze(2) # [B][S+1][1]
    py = py_am + py_lm - normalizers

    return (px, py)
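

# A minimal sketch (the helper below is hypothetical, not part of the library):
# it runs get_rnnt_logprobs() on random inputs and replays the recursion
# documented above in pure Python, as a sanity check on shapes and semantics.
def _demo_get_rnnt_logprobs():
    B, S, T, C = 2, 3, 4, 5
    lm = torch.randn(B, S + 1, C)
    am = torch.randn(B, T, C)
    symbols = torch.randint(0, C, (B, S))
    px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol=0)
    assert px.shape == (B, S, T + 1) and py.shape == (B, S + 1, T)
    # Naive version of the recursion, for batch element 0:
    #   p[s,t] = log_add(p[s-1,t] + px[s-1,t], p[s,t-1] + py[s,t-1])
    b = 0
    p = torch.full((S + 1, T + 1), float('-inf'))
    p[0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            terms = []
            if s > 0:
                terms.append(p[s - 1, t] + px[b, s - 1, t])
            if t > 0:
                terms.append(p[s, t - 1] + py[b, s, t - 1])
            p[s, t] = torch.logsumexp(torch.stack(terms), dim=0)
    # p[S][T] is the total log-prob of the pair of sequences; it should match
    # mutual_information_recursion(px, py)[0], if we read the recursion right.
    print("total log-prob p[S][T] =", p[S, T].item())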


def rnnt_loss_simple(lm: Tensor,
                     am: Tensor,
                     symbols: Tensor,
                     termination_symbol: int,
                     boundary: Optional[Tensor] = None) -> Tensor:
    """
    A simple case of the RNN-T loss, where the 'joiner' network is just addition.
    Returns negated total loss value.

    Args:
     lm: language-model part of unnormalized log-probs of symbols, with shape
        (B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes
     am: acoustic-model part of unnormalized log-probs of symbols, with shape
       (B, T, C), i.e. batch, frame, num_classes
     symbols: the symbol sequences, a LongTensor of shape [B][S], with elements in {0..C-1}.
     termination_symbol: the termination symbol, with 0 <= termination_symbol < C
     boundary: a LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T]
        if boundary is not supplied.
        Most likely you will want begin_symbol and begin_frame to be zero.
    Returns:
      a Tensor of shape (B,), containing the NEGATED total RNN-T loss values
      for each element of the batch (like log-probs of sequences).
    """
    px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol)
    return mutual_information_recursion(px, py, boundary)
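

# A hedged usage sketch for rnnt_loss_simple(); the shapes and the blank id
# below are made up for illustration (this helper is not part of the library).
def _demo_rnnt_loss_simple():
    B, S, T, C = 2, 3, 4, 5
    blank = 0  # hypothetical termination/next-frame symbol id
    lm = torch.randn(B, S + 1, C, requires_grad=True)
    am = torch.randn(B, T, C, requires_grad=True)
    symbols = torch.randint(1, C, (B, S))  # avoid the blank id, for clarity
    loss = rnnt_loss_simple(lm, am, symbols, termination_symbol=blank)
    # `loss` has shape (B,) and holds NEGATED losses (i.e. log-probs of the
    # sequences); negate and reduce before backprop when training.
    (-loss).sum().backward()
    assert lm.grad is not None and am.grad is not None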


def get_rnnt_logprobs_aux(lm: Tensor,
                          am: Tensor,
                          symbols: Tensor,
                          termination_symbol: int,
                          lm_only_scale: float = 0.1,
                          am_only_scale: float = 0.1) -> Tuple[Tensor, Tensor]:
    """
    Reduces the RNN-T problem (the simple case, where the joiner network is
    just addition) to a compact, standard form that can then be given
    (with boundaries) to mutual_information_recursion().  This version allows
    you to make the loss function one of the form:
          lm_only_scale * lm_probs +
          am_only_scale * am_probs +
          (1-lm_only_scale-am_only_scale) * combined_probs
    where lm_probs and am_probs are the log-probs given the language model
    and acoustic model independently.

    This function is called from
    rnnt_loss_aux(), but may be useful for other purposes.

    Args:
         lm:  Language model part of un-normalized logprobs of symbols, to be added to
              acoustic model part before normalizing.  Of shape:
                 [B][S+1][C]
              where B is the batch size, S is the maximum sequence length of
              the symbol sequence, possibly including the EOS symbol; and
              C is the size of the symbol vocabulary, including the
              termination/next-frame symbol.
              Conceptually, lm[b][s] is a vector of length [C] representing the
              "language model" part of the un-normalized logprobs of symbols,
              given all symbols *earlier than* s in the sequence.  The reason
              we still need this for position S is that we may still be emitting
              the termination/next-frame symbol at this point.
         am:  Acoustic-model part of un-normalized logprobs of symbols, to be added
              to language-model part before normalizing.  Of shape:
                 [B][T][C]
              where B is the batch size, T is the maximum sequence length of
              the acoustic sequences (in frames); and C is the size of the symbol
              vocabulary, including the termination/next-frame symbol.  It reflects
              the "acoustic" part of the probability of any given symbol appearing
              next on this frame.
          symbols: A LongTensor of shape [B][S], containing the symbols at each position
              of the sequence, possibly including EOS
          termination_symbol: The identity of the termination symbol, must be
               in {0..C-1}
          lm_only_scale: the scale on the "LM-only" part of the loss.
          am_only_scale: the scale on the "AM-only" part of the loss, for which
               we use an "averaged" LM (averaged over all histories, so
               effectively unigram).
    Returns: (px, py) (the names are quite arbitrary).
              px: logprobs, of shape [B][S][T+1]
              py: logprobs, of shape [B][S+1][T]
          in the recursion:
             p[b,0,0] = 0.0
             p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                p[b,s,t-1] + py[b,s,t-1])
          .. where p[b][s][t] is the "joint score" of the pair of subsequences
          of length s and t respectively.  px[b][s][t] represents the
          log-probability of extending the subsequences of length (s,t) by one
          in the s direction, given the particular symbol, and py[b][s][t]
          represents the log-probability of extending the subsequences of
          length (s,t) by one in the t direction, i.e. of emitting the
          termination/next-frame symbol.

          px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
          we cannot emit any symbols.  This is simply a way of incorporating
          the probability of the termination symbol on the last frame.
    """
    assert lm.ndim == 3 and am.ndim == 3 and lm.shape[0] == am.shape[0] \
        and lm.shape[2] == am.shape[2]
    (B, T, C) = am.shape
    S = lm.shape[1] - 1
    assert symbols.shape == (B, S)

    # Caution: some parts of this code are a little less clear than they could
    # be due to optimizations.  In particular it may not be totally obvious that
    # all of the logprobs here are properly normalized.  We test that
    # this code is invariant to adding constants in the appropriate ways.

    # We subtract am_max and lm_max to keep the values in a good range for
    # exp(), avoiding underflow or overflow.
    am_max, _ = torch.max(am, dim=2, keepdim=True)  # am_max: [B][T][1]
    lm_max, _ = torch.max(lm, dim=2, keepdim=True)  # lm_max: [B][S+1][1]
    am_probs = (am - am_max).exp()  # [B][T][C]
    lm_probs = (lm - lm_max).exp()  # [B][S+1][C]
    # normalizers: [B][S+1][T]
    normalizers = (torch.matmul(lm_probs, am_probs.transpose(1, 2)) + 1.0e-20).log()

    # normalizer per symbol position, if we take only the LM probs by themselves
    lmonly_normalizers = lm_probs.sum(dim=2, keepdim=True)  # [B][S+1][1]
    unigram_lm = torch.mean(lm_probs / lmonly_normalizers,
                            dim=(0, 1), keepdim=True) + 1.0e-20  # [1][1][C]
    amonly_normalizers = (torch.mv(am_probs.reshape(-1, C),
                                   unigram_lm.reshape(C))
                          .reshape(B, T, 1).log() + am_max)  # [B][T][1]
    amonly_normalizers = amonly_normalizers.transpose(1, 2)  # [B][1][T]
    unigram_lm = unigram_lm.log()
    lmonly_normalizers = lmonly_normalizers.log() + lm_max  # [B][S+1][1]; log-normalizer for the LM-only part of the prob.
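
    # Note: at this point amonly_normalizers[b][0][t] equals
    # log(sum_c p_unigram(c) * exp(am[b][t][c])), where p_unigram is the
    # probability-space unigram LM (the value of unigram_lm before the .log()
    # above): the log-normalizer for the AM scores under a fixed unigram LM.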


    # add lm_max and am_max to normalizers, to make it as if we had not
    # subtracted am_max and lm_max above.
    normalizers = normalizers + lm_max + am_max.transpose(1, 2)  # [B][S+1][T]

    # px is the probs of the actual symbols (not yet normalized)..
    px_am = torch.gather(am.unsqueeze(1).expand(B, S, T, C), dim=3,
                         index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1)).squeeze(-1) # [B][S][T]
    px_am = torch.cat((px_am,
                       torch.full((B, S, 1), float('-inf'),
                                  device=px_am.device, dtype=px_am.dtype)),
                      dim=2)  # now: [B][S][T+1], index [:,:,T] has -inf..


    px_lm = torch.gather(lm[:,:S], dim=2, index=symbols.unsqueeze(-1))  # [B][S][1]
    px_lm_unigram = torch.gather(unigram_lm.expand(B, S, C), dim=2,
                                 index=symbols.unsqueeze(-1))  # [B][S][1]

    px = px_am + px_lm  # [B][S][T+1], last slice indexed [:,:,T] is -inf
    px[:,:,:T] -= normalizers[:,:S,:] # px: [B][S][T+1]

    px_amonly = px_am + px_lm_unigram     # [B][S][T+1]
    px_amonly[:,:,:T] -= amonly_normalizers
    px_lmonly = px_lm - lmonly_normalizers[:,:S,:]


    # py is the probs of termination symbols, of shape [B][S+1][T]
    py_am = am[:,:,termination_symbol].unsqueeze(1) # [B][1][T]
    py_lm = lm[:,:,termination_symbol].unsqueeze(2) # [B][S+1][1]
    py = py_am + py_lm - normalizers

    py_lm_unigram = unigram_lm[0][0][termination_symbol] # scalar, normalized..
    py_amonly = py_am + py_lm_unigram - amonly_normalizers # [B][S+1][T]
    py_lmonly = py_lm - lmonly_normalizers # [B][S+1][T]


    combined_scale = 1.0 - lm_only_scale - am_only_scale

    # We need to avoid exact zeros in the scales because otherwise multiplying -inf
    # by zero generates nan.
    if lm_only_scale == 0.0:
        lm_only_scale = 1.0e-20
    if am_only_scale == 0.0:
        am_only_scale = 1.0e-20

    px_interp = px * combined_scale + px_lmonly * lm_only_scale + px_amonly * am_only_scale
    py_interp = py * combined_scale + py_lmonly * lm_only_scale + py_amonly * am_only_scale

    print("px_interp = ", px_interp)
    print("py_interp = ", py_interp)
    return (px_interp, py_interp)
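

# A hedged sanity-check sketch, motivated by the invariance comment inside
# get_rnnt_logprobs_aux(): adding a per-position constant to the un-normalized
# lm scores should leave (px, py) unchanged, since the normalizers absorb it.
# This helper and its shapes are made up for illustration.
def _demo_aux_invariance():
    B, S, T, C = 2, 3, 4, 5
    lm = torch.randn(B, S + 1, C)
    am = torch.randn(B, T, C)
    symbols = torch.randint(0, C, (B, S))
    px1, py1 = get_rnnt_logprobs_aux(lm, am, symbols, termination_symbol=0)
    shift = torch.randn(B, S + 1, 1)  # constant per (batch, position)
    px2, py2 = get_rnnt_logprobs_aux(lm + shift, am, symbols,
                                     termination_symbol=0)
    # Compare up to float32 round-off; skip the [:, :, T] column, which is -inf.
    assert torch.allclose(px1[:, :, :T], px2[:, :, :T], atol=1.0e-4)
    assert torch.allclose(py1, py2, atol=1.0e-4)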


def rnnt_loss_aux(lm: Tensor,
                  am: Tensor,
                  symbols: Tensor,
                  termination_symbol: int,
                  lm_only_scale: float = 0.1,
                  am_only_scale: float = 0.1,
                  boundary: Optional[Tensor] = None) -> Tensor:
    """
    A simple case of the RNN-T loss, where the 'joiner' network is just
    addition, with auxiliary LM-only and AM-only terms interpolated into the
    loss (see get_rnnt_logprobs_aux()).
    Returns negated total loss value.

    Args:
     lm: language-model part of unnormalized log-probs of symbols, with shape
        (B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes.
        These are assumed to be well-normalized, in the sense that we could
        use them as probabilities separately from the am scores.
     am: acoustic-model part of unnormalized log-probs of symbols, with shape
       (B, T, C), i.e. batch, frame, num_classes
     symbols: the symbol sequences, a LongTensor of shape [B][S], with elements in {0..C-1}.
     termination_symbol: the termination symbol, with 0 <= termination_symbol < C
     lm_only_scale: the scale on the "LM-only" part of the loss.
     am_only_scale: the scale on the "AM-only" part of the loss, for which we use
       an "averaged" LM (averaged over all histories, so effectively unigram).
     boundary: a LongTensor of shape [B, 4] with elements interpreted as
        [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T]
        if boundary is not supplied.
        Most likely you will want begin_symbol and begin_frame to be zero.
    Returns:
      a Tensor of shape (B,), containing the NEGATED total RNN-T loss values
      for each element of the batch (like log-probs of sequences).
    """
    px, py = get_rnnt_logprobs_aux(lm, am, symbols, termination_symbol,
                                   lm_only_scale, am_only_scale)
    return mutual_information_recursion(px, py, boundary)
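

# A hedged usage sketch for rnnt_loss_aux() with an explicit `boundary`;
# the helper and all values below are made up for illustration.
def _demo_rnnt_loss_aux():
    B, S, T, C = 2, 3, 4, 5
    blank = 0  # hypothetical termination/next-frame symbol id
    lm = torch.randn(B, S + 1, C)
    am = torch.randn(B, T, C)
    symbols = torch.randint(1, C, (B, S))
    # Rows are [begin_symbol, begin_frame, end_symbol, end_frame]; using the
    # full lengths here is equivalent to the default boundary [0, 0, S, T].
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)
    loss = rnnt_loss_aux(lm, am, symbols, termination_symbol=blank,
                         lm_only_scale=0.1, am_only_scale=0.1,
                         boundary=boundary)
    print("negated RNN-T losses, one per sequence:", loss)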