Commit 12c90639 authored by change

init

parent 417b607b
/LocalData/dataset/LibriSpeech
train-clean-100/103/1240/103-1240-0000.flac 225360
train-clean-100/103/1240/103-1240-0001.flac 255120
train-clean-100/103/1240/103-1240-0002.flac 223120
train-clean-100/103/1240/103-1240-0003.flac 235360
train-clean-100/103/1240/103-1240-0004.flac 200240
train-clean-100/103/1240/103-1240-0005.flac 242800
train-clean-100/103/1240/103-1240-0006.flac 153280
train-clean-100/103/1240/103-1240-0007.flac 240560
train-clean-100/103/1240/103-1240-0008.flac 246960
train-clean-100/103/1240/103-1240-0009.flac 160480
train-clean-100/103/1240/103-1240-0010.flac 236880
train-clean-100/103/1240/103-1240-0011.flac 234480
train-clean-100/103/1240/103-1240-0012.flac 243040
train-clean-100/103/1240/103-1240-0013.flac 244160
train-clean-100/103/1240/103-1240-0014.flac 223360
train-clean-100/103/1240/103-1240-0015.flac 60960
train-clean-100/103/1240/103-1240-0016.flac 250640
train-clean-100/103/1240/103-1240-0017.flac 229040
train-clean-100/103/1240/103-1240-0018.flac 185760
train-clean-100/103/1240/103-1240-0019.flac 246480
train-clean-100/103/1240/103-1240-0020.flac 214640
train-clean-100/103/1240/103-1240-0021.flac 236960
train-clean-100/103/1240/103-1240-0022.flac 262000
train-clean-100/103/1240/103-1240-0023.flac 194400
train-clean-100/103/1240/103-1240-0024.flac 244320
train-clean-100/103/1240/103-1240-0025.flac 241920
train-clean-100/103/1240/103-1240-0026.flac 133360
train-clean-100/103/1240/103-1240-0027.flac 223440
train-clean-100/103/1240/103-1240-0028.flac 250400
train-clean-100/103/1240/103-1240-0029.flac 244320
train-clean-100/103/1240/103-1240-0030.flac 232320
train-clean-100/103/1240/103-1240-0031.flac 269760
train-clean-100/103/1240/103-1240-0032.flac 236400
train-clean-100/103/1240/103-1240-0033.flac 230640
train-clean-100/103/1240/103-1240-0034.flac 246480
train-clean-100/103/1240/103-1240-0035.flac 256720
train-clean-100/103/1240/103-1240-0036.flac 200320
train-clean-100/103/1240/103-1240-0037.flac 237040
train-clean-100/103/1240/103-1240-0038.flac 114480
train-clean-100/103/1240/103-1240-0039.flac 230800
train-clean-100/103/1240/103-1240-0040.flac 234720
train-clean-100/103/1240/103-1240-0041.flac 216160
train-clean-100/103/1240/103-1240-0042.flac 249680
train-clean-100/103/1240/103-1240-0043.flac 236160
train-clean-100/103/1240/103-1240-0044.flac 262240
train-clean-100/103/1240/103-1240-0045.flac 250800
train-clean-100/103/1240/103-1240-0046.flac 222800
train-clean-100/103/1240/103-1240-0047.flac 206320
train-clean-100/103/1240/103-1240-0048.flac 236320
train-clean-100/103/1240/103-1240-0049.flac 244560
train-clean-100/103/1240/103-1240-0050.flac 224400
train-clean-100/103/1240/103-1240-0051.flac 245760
train-clean-100/103/1240/103-1240-0052.flac 236640
train-clean-100/103/1240/103-1240-0053.flac 218640
train-clean-100/103/1240/103-1240-0054.flac 261360
train-clean-100/103/1240/103-1240-0055.flac 179920
train-clean-100/103/1240/103-1240-0056.flac 229040
train-clean-100/103/1240/103-1240-0057.flac 109680
train-clean-100/103/1241/103-1241-0000.flac 255440
train-clean-100/103/1241/103-1241-0001.flac 248800
train-clean-100/103/1241/103-1241-0002.flac 249040
train-clean-100/103/1241/103-1241-0003.flac 222160
train-clean-100/103/1241/103-1241-0004.flac 236080
train-clean-100/103/1241/103-1241-0005.flac 224400
train-clean-100/103/1241/103-1241-0006.flac 243760
train-clean-100/103/1241/103-1241-0007.flac 242320
train-clean-100/103/1241/103-1241-0008.flac 242160
train-clean-100/103/1241/103-1241-0009.flac 222400
train-clean-100/103/1241/103-1241-0010.flac 253920
train-clean-100/103/1241/103-1241-0011.flac 231760
train-clean-100/103/1241/103-1241-0012.flac 239680
train-clean-100/103/1241/103-1241-0013.flac 236960
train-clean-100/103/1241/103-1241-0014.flac 242080
train-clean-100/103/1241/103-1241-0015.flac 224160
train-clean-100/103/1241/103-1241-0016.flac 234640
train-clean-100/103/1241/103-1241-0017.flac 254240
train-clean-100/103/1241/103-1241-0018.flac 150960
train-clean-100/103/1241/103-1241-0019.flac 48400
train-clean-100/103/1241/103-1241-0020.flac 155360
train-clean-100/103/1241/103-1241-0021.flac 242880
train-clean-100/103/1241/103-1241-0022.flac 261600
train-clean-100/103/1241/103-1241-0023.flac 266720
train-clean-100/103/1241/103-1241-0024.flac 254240
train-clean-100/103/1241/103-1241-0025.flac 77280
train-clean-100/103/1241/103-1241-0026.flac 176080
train-clean-100/103/1241/103-1241-0027.flac 238080
train-clean-100/103/1241/103-1241-0028.flac 248880
train-clean-100/103/1241/103-1241-0029.flac 244960
train-clean-100/103/1241/103-1241-0030.flac 247520
train-clean-100/103/1241/103-1241-0031.flac 209600
train-clean-100/103/1241/103-1241-0032.flac 224080
train-clean-100/103/1241/103-1241-0033.flac 251920
train-clean-100/103/1241/103-1241-0034.flac 270560
train-clean-100/103/1241/103-1241-0035.flac 248800
train-clean-100/103/1241/103-1241-0036.flac 249040
train-clean-100/103/1241/103-1241-0037.flac 204400
train-clean-100/103/1241/103-1241-0038.flac 238960
train-clean-100/103/1241/103-1241-0039.flac 258160
train-clean-100/103/1241/103-1241-0040.flac 220560
train-clean-100/103/1241/103-1241-0041.flac 252240
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9
10 10
11 11
12 12
13 13
14 14
15 15
16 16
17 17
18 18
19 19
20 20
21 21
22 22
23 23
24 24
25 25
26 26
27 27
28 28
29 29
30 30
31 31
32 32
33 33
34 34
35 35
36 36
37 37
38 38
39 39
40 40
41 41
42 42
43 43
44 44
45 45
46 46
47 47
48 48
49 49
50 50
51 51
52 52
53 53
54 54
55 55
56 56
57 57
58 58
59 59
60 60
61 61
62 62
63 63
64 64
65 65
66 66
67 67
68 68
69 69
70 70
71 71
72 72
73 73
74 74
75 75
76 76
77 77
78 78
79 79
80 80
81 81
82 82
83 83
84 84
85 85
86 86
87 87
88 88
89 89
90 90
91 91
92 92
93 93
94 94
95 95
96 96
97 97
98 98
99 99
100 100
101 101
102 102
103 103
104 104
105 105
106 106
107 107
108 108
109 109
110 110
111 111
112 112
113 113
114 114
115 115
116 116
117 117
118 118
119 119
120 120
121 121
122 122
123 123
124 124
125 125
126 126
127 127
128 128
129 129
130 130
131 131
132 132
133 133
134 134
135 135
136 136
137 137
138 138
139 139
140 140
141 141
142 142
143 143
144 144
145 145
146 146
147 147
148 148
149 149
150 150
151 151
152 152
153 153
154 154
155 155
156 156
157 157
158 158
159 159
160 160
161 161
162 162
163 163
164 164
165 165
166 166
167 167
168 168
169 169
170 170
171 171
172 172
173 173
174 174
175 175
176 176
177 177
178 178
179 179
180 180
181 181
182 182
183 183
184 184
185 185
186 186
187 187
188 188
189 189
190 190
191 191
192 192
193 193
194 194
195 195
196 196
197 197
198 198
199 199
200 200
201 201
202 202
203 203
204 204
205 205
206 206
207 207
208 208
209 209
210 210
211 211
212 212
213 213
214 214
215 215
216 216
217 217
218 218
219 219
220 220
221 221
222 222
223 223
224 224
225 225
226 226
227 227
228 228
229 229
230 230
231 231
232 232
233 233
234 234
235 235
236 236
237 237
238 238
239 239
240 240
241 241
242 242
243 243
244 244
245 245
246 246
247 247
248 248
249 249
250 250
251 251
252 252
253 253
254 254
255 255
256 256
257 257
258 258
259 259
260 260
261 261
262 262
263 263
264 264
265 265
266 266
267 267
268 268
269 269
270 270
271 271
272 272
273 273
274 274
275 275
276 276
277 277
278 278
279 279
280 280
281 281
282 282
283 283
284 284
285 285
286 286
287 287
288 288
289 289
290 290
291 291
292 292
293 293
294 294
295 295
296 296
297 297
298 298
299 299
300 300
301 301
302 302
303 303
304 304
305 305
306 306
307 307
308 308
309 309
310 310
311 311
312 312
313 313
314 314
315 315
316 316
317 317
318 318
319 319
320 320
321 321
322 322
323 323
324 324
325 325
326 326
327 327
328 328
329 329
330 330
331 331
332 332
333 333
334 334
335 335
336 336
337 337
338 338
339 339
340 340
341 341
342 342
343 343
344 344
345 345
346 346
347 347
348 348
349 349
350 350
351 351
352 352
353 353
354 354
355 355
356 356
357 357
358 358
359 359
360 360
361 361
362 362
363 363
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
"""
We just merge all the required modules and functions into one python file.
It is for easily use the pre-trained model to extract features.
"""
import math
import numpy as np
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch import Tensor
from typing import Any, Dict, List, Tuple, Callable, Optional
logger = logging.getLogger(__name__)
# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
if module_name == "TransformerEncoderBase":
return "TransformerEncoder"
else:
return module_name
def utils_make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
"""Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols are ignored.
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA. In particular XLA
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
# how to handle the dtype kwarg in cumsum.
mask = tensor.ne(padding_idx).int()
return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
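# Illustrative sanity check (not part of the original file): with padding_idx = 1,
# positions start at padding_idx + 1 = 2 and padded slots keep the value 1.
#   t = torch.tensor([[5, 6, 7, 1], [8, 1, 1, 1]])
#   utils_make_positions(t, padding_idx=1)
#   # -> tensor([[2, 3, 4, 1], [2, 1, 1, 1]])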
def utils_item(tensor):
# tpu-comment: making this a no-op for xla devices.
if torch.is_tensor(tensor) and tensor.device.type == "xla":
return tensor.detach()
if hasattr(tensor, "item"):
return tensor.item()
if hasattr(tensor, "__getitem__"):
return tensor[0]
return tensor
def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
"""
Helper to wrap layers/modules in FSDP. This falls back to a no-op if
fairscale is not available.
Args:
module (nn.Module): module to (maybe) wrap
min_num_params (int, Optional): minimum number of layer params to wrap
"""
try:
from fairscale.nn import wrap
if min_num_params is not None:
num_params = sum(p.numel() for p in module.parameters())
if num_params >= min_num_params:
return wrap(module, **kwargs)
else:
return module
else:
return wrap(module, **kwargs)
except ImportError:
return module
def quant_noise(module, p, block_size):
"""
Wraps modules and applies quantization noise to the weights for
subsequent quantization with Iterative Product Quantization as
described in "Training with Quantization Noise for Extreme Model Compression"
Args:
- module: nn.Module
- p: amount of Quantization Noise
- block_size: size of the blocks for subsequent quantization with iPQ
Remarks:
- Module weights must have the right sizes wrt the block size
- Only Linear, Embedding and Conv2d modules are supported for the moment
- For more detail on how to quantize by blocks with convolutional weights,
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
- We implement the simplest form of noise here as stated in the paper
which consists in randomly dropping blocks
"""
# if no quantization noise, don't register hook
if p <= 0:
return module
# supported modules
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
# test whether module.weight has the right sizes wrt block_size
is_conv = module.weight.ndim == 4
# 2D matrix
if not is_conv:
assert (
module.weight.size(1) % block_size == 0
), "Input features must be a multiple of block sizes"
# 4D matrix
else:
# 1x1 convolutions
if module.kernel_size == (1, 1):
assert (
module.in_channels % block_size == 0
), "Input channels must be a multiple of block sizes"
# regular convolutions
else:
k = module.kernel_size[0] * module.kernel_size[1]
assert k % block_size == 0, "Kernel size must be a multiple of block size"
def _forward_pre_hook(mod, input):
# no noise for evaluation
if mod.training:
if not is_conv:
# gather weight and sizes
weight = mod.weight
in_features = weight.size(1)
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
mask = torch.zeros(
in_features // block_size * out_features, device=weight.device
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
else:
# gather weight and sizes
weight = mod.weight
in_channels = mod.in_channels
out_channels = mod.out_channels
# split weight matrix into blocks and randomly drop selected blocks
if mod.kernel_size == (1, 1):
mask = torch.zeros(
int(in_channels // block_size * out_channels),
device=weight.device,
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
mask = torch.zeros(
weight.size(0), weight.size(1), device=weight.device
)
mask.bernoulli_(p)
mask = (
mask.unsqueeze(2)
.unsqueeze(3)
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
)
# scale weights and apply mask
mask = mask.to(
torch.bool
) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
module.register_forward_pre_hook(_forward_pre_hook)
return module
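# Usage sketch (illustrative, values chosen for the example): wrap a linear layer so
# that during training 8-wide weight blocks are zeroed with probability p = 0.1 and
# the surviving weights are rescaled by 1 / (1 - p); at evaluation time it is a no-op.
#   layer = quant_noise(nn.Linear(64, 64, bias=False), p=0.1, block_size=8)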
def relu_squared(x: torch.Tensor):
return F.relu(x).pow(2)
def gelu(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(x.float()).type_as(x)
def gelu_accurate(x):
if not hasattr(gelu_accurate, "_a"):
gelu_accurate._a = math.sqrt(2 / math.pi)
return (
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
)
def get_activation_fn(activation: str) -> Callable:
"""Returns the activation function corresponding to `activation`"""
if activation == "relu":
return F.relu
elif activation == "relu_squared":
return relu_squared
elif activation == "gelu":
return gelu
elif activation == "gelu_fast":
logger.warning(
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
)
return gelu_accurate
elif activation == "gelu_accurate":
return gelu_accurate
elif activation == "tanh":
return torch.tanh
elif activation == "linear":
return lambda x: x
elif activation == "swish":
return F.silu  # return the function, not the nn.SiLU class, so activation_fn(x) works
else:
raise RuntimeError("--activation-fn {} not supported".format(activation))
def softmax(x, dim: int, onnx_trace: bool = False):
if onnx_trace:
return F.softmax(x.float(), dim=dim)
else:
return F.softmax(x, dim=dim, dtype=torch.float32)
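# Note: the non-ONNX path computes the softmax in float32 (dtype=torch.float32),
# which keeps attention weights numerically stable under fp16/mixed-precision training.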
def compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
require_same_masks: bool = True,
mask_dropout: float = 0.0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape
Args:
shape: the shape for which to compute masks;
should be of size 2, where the first element is the batch size and the second is the number of timesteps
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as the start of a span to be masked. This will be multiplied by
the number of timesteps divided by the length of the mask span to mask approximately this percentage of all elements.
However, due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_type: how to compute mask lengths
static = fixed size
uniform = sample from uniform distribution [mask_other, mask_length*2]
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
poisson = sample from poisson distribution with lambda = mask_length
mask_other: secondary parameter for the chosen mask_type (uniform lower bound / normal stdev)
min_masks: minimum number of masked spans
no_overlap: if true, will use an alternative recursive algorithm that prevents spans from overlapping
min_space: only used if no_overlap is True; this is how many elements to keep unmasked between spans
require_same_masks: if true, will randomly drop masks until the same number of masks remains in each sample
mask_dropout: randomly drop this percentage of masks in each example
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
lengths = np.random.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = np.random.randint(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - keep_length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = np.fromiter(
(e - s if e - s >= length + min_space else 0 for s, e in parts),
np.int64,  # np.int was removed in NumPy 1.24
)
l_sum = np.sum(lens)
if l_sum == 0:
break
probs = lens / np.sum(lens)
c = np.random.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray(
[
mask_idc[j] + offset
for j in range(len(mask_idc))
for offset in range(lengths[j])
]
)
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len and require_same_masks:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
if mask_dropout > 0:
num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
mask_idc = np.random.choice(
mask_idc, len(mask_idc) - num_holes, replace=False
)
mask[i, mask_idc] = True
return mask
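# Illustrative example (hypothetical numbers, wav2vec 2.0-style settings): mask a
# batch of 2 sequences of 100 timesteps in spans of length 10 with mask_prob = 0.65.
#   mask = compute_mask_indices((2, 100), padding_mask=None,
#                               mask_prob=0.65, mask_length=10)
#   # mask is a (2, 100) boolean np.ndarray; True marks masked positions.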
def init_bert_params(module):
"""
Initialize the weights specific to the BERT Model.
This overrides the default initializations depending on the specified arguments.
1. If normal_init_linear_weights is set then weights of linear
layer will be initialized using the normal distribution and
bias will be set to the specified value.
2. If normal_init_embed_weights is set then weights of embedding
layer will be initialized using the normal distribution.
3. If normal_init_proj_weights is set then the weights of the
input projections of MultiheadAttention are initialized using
the normal distribution (to be validated).
"""
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, MultiheadAttention):
normal_(module.q_proj.weight.data)
normal_(module.k_proj.weight.data)
normal_(module.v_proj.weight.data)
def pad_to_multiple(x, multiple, dim=-1, value=0):
# Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
if x is None:
return None, 0
tsz = x.size(dim)
m = tsz / multiple
remainder = math.ceil(m) * multiple - tsz
if m.is_integer():
return x, 0
pad_offset = (0,) * (-1 - dim) * 2
return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
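# Example (illustrative): pad the time dimension of a B x T x C tensor (dim=-2)
# up to a multiple of 4.
#   x = torch.zeros(2, 10, 8)
#   x, pad = pad_to_multiple(x, 4, dim=-2)   # x.shape == (2, 12, 8); pad == 2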
def is_xla_tensor(tensor):
return torch.is_tensor(tensor) and tensor.device.type == "xla"
def index_put(tensor, indices, value):
if is_xla_tensor(tensor):
for _ in range(indices.dim(), tensor.dim()):
indices = indices.unsqueeze(-1)
if indices.size(-1) < tensor.size(-1):
indices = indices.expand_as(tensor)
tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
else:
tensor[indices] = value
return tensor
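# Note: index_put above is a drop-in replacement for `tensor[indices] = value`
# that also works on XLA devices, where boolean masked assignment is not supported;
# there it blends the old and new values arithmetically instead.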
def PositionalEmbedding(
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
learned: bool = False,
):
if learned:
# if padding_idx is specified then offset the embedding ids by
# this index and adjust num_embeddings appropriately
# TODO: The right place for this offset would be inside
# LearnedPositionalEmbedding. Move this there for a cleaner implementation.
if padding_idx is not None:
num_embeddings = num_embeddings + padding_idx + 1
m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
if padding_idx is not None:
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(
embedding_dim,
padding_idx,
init_size=num_embeddings + padding_idx + 1,
)
return m
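# Example (illustrative): a learned positional table for inputs up to 1024 tokens
# with pad id 1; the table is offset by padding_idx + 1 internally.
#   pos = PositionalEmbedding(1024, 768, padding_idx=1, learned=True)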
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
if torch.jit.is_scripting() or torch.jit.is_tracing():
export = True
if not export and torch.cuda.is_available() and has_fused_layernorm:
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
class TransformerEncoderBase(nn.Module):
"""
Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
cfg: model configuration (a namespace with the encoder settings)
dictionary: deprecated; pass None
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, cfg, dictionary, embed_tokens, use_rel_pos_enc=False, scaling_for_att=1.0):
self.cfg = cfg
super().__init__()
self.register_buffer("version", torch.Tensor([3]))
self.dropout_module = FairseqDropout(
cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
)
self.encoder_layerdrop = cfg.encoder.layerdrop
embed_dim = embed_tokens.embedding_dim if embed_tokens is not None else cfg.encoder.embed_dim
self.padding_idx = embed_tokens.padding_idx if embed_tokens is not None else 1
self.max_source_positions = cfg.max_source_positions
self.embed_tokens = embed_tokens
self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
self.embed_positions = (
PositionalEmbedding(
cfg.max_source_positions,
embed_dim,
self.padding_idx,
learned=cfg.encoder.learned_pos,
)
if not cfg.no_token_positional_embeddings
else None
)
if cfg.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
else:
self.layernorm_embedding = None
if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
self.quant_noise = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=False),
cfg.quant_noise.pq,
cfg.quant_noise.pq_block_size,
)
else:
self.quant_noise = None
if self.encoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
else:
self.layers = nn.ModuleList([])
self.use_rel_pos_enc = use_rel_pos_enc
self.scaling_for_att = scaling_for_att
self.layers.extend(
[self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]
)
self.num_layers = len(self.layers)
if cfg.encoder.normalize_before:
self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
else:
self.layer_norm = None
if self.use_rel_pos_enc:
self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.encoder.attention_heads, 160)
def build_encoder_layer(self, cfg):
layer = TransformerEncoderLayerBase(cfg, has_relative_attention_bias=self.use_rel_pos_enc, scaling_for_att=self.scaling_for_att)
checkpoint = cfg.checkpoint_activations
if checkpoint:
raise ValueError("We don't support checkpoint_activations for now! Please set cfg.checkpoint_activations=False.")
min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
return layer
def forward_embedding(
self, src_tokens, token_embedding: Optional[torch.Tensor] = None
):
# embed tokens and positions
if token_embedding is None:
token_embedding = self.embed_tokens(src_tokens)
x = embed = self.embed_scale * token_embedding
if self.embed_positions is not None:
x = embed + self.embed_positions(src_tokens)
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
if self.quant_noise is not None:
x = self.quant_noise(x)
return x, embed
def forward(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
uniformity_layers: Optional[List[int]] = None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
return self.forward_scriptable(
src_tokens, src_lengths, return_all_hiddens, token_embeddings, uniformity_layers
)
# TorchScript doesn't support the super() method, so a scriptable subclass
# can't access the base class implementation in TorchScript.
# The current workaround is to add a helper function with a different name and
# call that helper from the scriptable subclass.
def forward_scriptable(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
uniformity_layers: Optional[List[int]] = None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
# account for padding while computing the representation
if has_pads:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if self.use_rel_pos_enc:
x_len = x.shape[0]
pos_seq = torch.arange(0, x_len).long().to(x.device)
pos_seq = pos_seq[:, None] - pos_seq[None, :]
pos_k, pos_v = self.pos_emb(pos_seq)
else:
pos_k = None
encoder_states = []
uniformity_hiddens = []
if return_all_hiddens:
encoder_states.append(x)
if uniformity_layers is not None and 0 in uniformity_layers:
x = F.normalize(x.float(), dim=-1).type_as(x)
uniformity_hiddens.append(x)
# encoder layers
for i, layer in enumerate(self.layers):
x = layer(
x, encoder_padding_mask=encoder_padding_mask if has_pads else None,
pos_bias=pos_k,
)
if uniformity_layers is not None and i+1 in uniformity_layers:
x = F.normalize(x.float(), dim=-1).type_as(x)
uniformity_hiddens.append(x)
if return_all_hiddens:
assert encoder_states is not None
encoder_states.append(x)
if self.layer_norm is not None:
x = self.layer_norm(x)
# The PyTorch Mobile lite interpreter does not support returning NamedTuple from
# `forward`, so we use a dictionary instead.
# TorchScript does not support mixed values so the values are all lists.
# The empty list is equivalent to None.
src_lengths = (
src_tokens.ne(self.padding_idx)
.sum(dim=1, dtype=torch.int32)
.reshape(-1, 1)
.contiguous()
)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"uniformity_hiddens": uniformity_hiddens, # List[T x B x C]
"src_tokens": [],
"src_lengths": [src_lengths],
}
def forward_torchscript(self, net_input: Dict[str, Tensor]):
"""A TorchScript-compatible version of forward.
Encoders which use additional arguments may want to override
this method for TorchScript compatibility.
"""
if torch.jit.is_scripting():
return self.forward(
src_tokens=net_input["src_tokens"],
src_lengths=net_input["src_lengths"],
)
else:
return self.forward_non_torchscript(net_input)
@torch.jit.unused
def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
encoder_input = {
k: v for k, v in net_input.items() if k != "prev_output_tokens"
}
return self.forward(**encoder_input)
@torch.jit.export
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if len(encoder_out["encoder_out"]) == 0:
new_encoder_out = []
else:
new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
if len(encoder_out["encoder_padding_mask"]) == 0:
new_encoder_padding_mask = []
else:
new_encoder_padding_mask = [
encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
]
if len(encoder_out["encoder_embedding"]) == 0:
new_encoder_embedding = []
else:
new_encoder_embedding = [
encoder_out["encoder_embedding"][0].index_select(0, new_order)
]
if len(encoder_out["src_tokens"]) == 0:
src_tokens = []
else:
src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
if len(encoder_out["src_lengths"]) == 0:
src_lengths = []
else:
src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": src_tokens, # B x T
"src_lengths": src_lengths, # B x 1
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
if self.embed_positions is None:
return self.max_source_positions
return min(self.max_source_positions, self.embed_positions.max_positions)
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
print("deleting {0}".format(weights_key))
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
for i in range(self.num_layers):
# update layer norms
self.layers[i].upgrade_state_dict_named(
state_dict, "{}.layers.{}".format(name, i)
)
version_key = "{}.version".format(name)
if utils_item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
def set_num_updates(self, num_updates):
"""State from trainer to pass along to model at every update."""
def _apply(m):
if hasattr(m, "set_num_updates") and m != self:
m.set_num_updates(num_updates)
self.apply(_apply)
class TransformerEncoderLayerBase(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*cfg.encoder.normalize_before* to ``True``.
Args:
cfg: model configuration (a namespace with the encoder settings)
"""
def __init__(self, cfg, has_relative_attention_bias=False, scaling_for_att=1.0):
super().__init__()
self.cfg = cfg
self.embed_dim = cfg.encoder.embed_dim
self.quant_noise = cfg.quant_noise.pq
self.quant_noise_block_size = cfg.quant_noise.pq_block_size
self.self_attn = self.build_self_attention(self.embed_dim, cfg, has_relative_attention_bias=has_relative_attention_bias, scaling_for_att=scaling_for_att)
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.dropout_module = FairseqDropout(
cfg.dropout, module_name=self.__class__.__name__
)
self.activation_fn = get_activation_fn(activation=cfg.activation_fn)
activation_dropout_p = cfg.activation_dropout
if activation_dropout_p == 0:
# for backwards compatibility with models that use cfg.relu_dropout
activation_dropout_p = cfg.relu_dropout or 0
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)
self.normalize_before = cfg.encoder.normalize_before
self.fc1 = self.build_fc1(
self.embed_dim,
cfg.encoder.ffn_embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.fc2 = self.build_fc2(
cfg.encoder.ffn_embed_dim,
self.embed_dim,
self.quant_noise,
self.quant_noise_block_size,
)
self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
if has_relative_attention_bias:
self.norm_k = LayerNorm(self.embed_dim // cfg.encoder.attention_heads)
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
)
def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
)
def build_self_attention(self, embed_dim, cfg, has_relative_attention_bias=False, scaling_for_att=1.0):
return MultiheadAttention(
embed_dim,
cfg.encoder.attention_heads,
dropout=cfg.attention_dropout,
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
has_relative_attention_bias=has_relative_attention_bias,
scaling_for_att=scaling_for_att,
)
def residual_connection(self, x, residual):
return residual + x
def upgrade_state_dict_named(self, state_dict, name):
"""
Rename layer norm states from `...layer_norms.0.weight` to
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
`...final_layer_norm.weight`
"""
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layer_norms.{}.{}".format(name, old, m)
if k in state_dict:
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
del state_dict[k]
def forward(
self,
x,
encoder_padding_mask: Optional[Tensor],
attn_mask: Optional[Tensor] = None,
pos_bias=None,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, seq_len)` where padding elements are indicated by ``1``.
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
where `tgt_len` is the length of output and `src_len` is the
length of input, though here both are equal to `seq_len`.
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
useful for strided self-attention.
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
# anything in original attn_mask = 1, becomes -1e8
# anything in original attn_mask = 0, becomes 0
# Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if pos_bias is not None:
pos_bias = self.norm_k(pos_bias)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
position_bias=pos_bias,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x
class TransformerEncoder(nn.Module):
"""
wav2vec-style transformer encoder.
"""
def __init__(self, args):
super().__init__()
self.args = args  # kept so that max_positions() below can read args.max_positions
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.required_seq_len_multiple = args.required_seq_len_multiple
self.pos_conv = nn.Conv1d(
self.embedding_dim,
self.embedding_dim,
kernel_size=args.conv_pos,
padding=args.conv_pos // 2,
groups=args.conv_pos_groups,
)
dropout = 0
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
layers = []
self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False)
for _ in range(args.encoder_layers):
layer = TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
has_relative_attention_bias=self.use_rel_pos_enc,
scaling_for_att=getattr(args, "scaling_for_att", 1.0)
)
if args.checkpoint_activations:
raise ValueError("We don't support checkpoint_activations for now! Please set checkpoint_activations=False.")
layers.append(layer)
self.layers = nn.ModuleList(layers)
self.layer_norm_first = args.layer_norm_first
self.layer_norm = LayerNorm(self.embedding_dim)
self.layerdrop = args.encoder_layerdrop
if self.use_rel_pos_enc:
self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, 160)
self.apply(init_bert_params)
def forward(self, x, padding_mask=None, layer=None, conv_pos=True):
x, layer_results = self.extract_features(x, padding_mask, layer, conv_pos)
if self.layer_norm_first and (layer is None or layer >= len(self.layers) - 1):
x = self.layer_norm(x)
return x, layer_results
def extract_features(self, x, padding_mask=None, tgt_layer=None, conv_pos=True):
if padding_mask is not None:
x = index_put(x, padding_mask, 0)
if conv_pos:
x_conv = self.pos_conv(x.transpose(1, 2))
x_conv = x_conv.transpose(1, 2)
x = x + x_conv
if not self.layer_norm_first:
x = self.layer_norm(x)
# pad to the sequence length dimension
x, pad_length = pad_to_multiple(
x, self.required_seq_len_multiple, dim=-2, value=0
)
if pad_length > 0 and padding_mask is None:
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
padding_mask[:, -pad_length:] = True
else:
padding_mask, _ = pad_to_multiple(
padding_mask, self.required_seq_len_multiple, dim=-1, value=True
)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if self.use_rel_pos_enc:
x_len = x.shape[0]
pos_seq = torch.arange(0, x_len).long().to(x.device)
pos_seq = pos_seq[:, None] - pos_seq[None, :]
pos_k, pos_v = self.pos_emb(pos_seq)
else:
pos_k = None
layer_results = []
r = None
for i, layer in enumerate(self.layers):
dropout_probability = np.random.random()
if not self.training or (dropout_probability > self.layerdrop):
x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_k)
if tgt_layer is not None:
# unpad if needed
if pad_length > 0:
layer_results.append(
x[:-pad_length]
# (
# x[:-pad_length],
# z[:, :-pad_length, :-pad_length]
# if z is not None
# else z,
# )
)
else:
# layer_results.append((x, z))
layer_results.append(x)
if i == tgt_layer:
r = x
break
if r is not None:
x = r
# T x B x C -> B x T x C
x = x.transpose(0, 1)
# undo padding
if pad_length > 0:
x = x[:, :-pad_length]
return x, layer_results
def max_positions(self):
"""Maximum output length supported by the encoder."""
return self.args.max_positions
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
return state_dict
class TransformerSentenceEncoderLayer(nn.Module):
"""
wav2vec-style transformer layer
"""
def __init__(
self,
embedding_dim: int = 768,
ffn_embedding_dim: int = 3072,
num_attention_heads: int = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
has_relative_attention_bias: bool = False,
scaling_for_att: float = 1.0,
) -> None:
super().__init__()
# Initialize parameters
self.embedding_dim = embedding_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
# Initialize blocks
self.activation_fn = get_activation_fn(activation_fn)
self.self_attn = MultiheadAttention(
self.embedding_dim,
num_attention_heads,
dropout=attention_dropout,
self_attention=True,
has_relative_attention_bias=has_relative_attention_bias,
scaling_for_att=scaling_for_att
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(self.activation_dropout)
self.dropout3 = nn.Dropout(dropout)
self.layer_norm_first = layer_norm_first
# layer norm associated with the self attention layer
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
# layer norm associated with the position wise feed-forward NN
self.final_layer_norm = LayerNorm(self.embedding_dim)
if has_relative_attention_bias:
self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
att_args=None,
pos_bias=None,
):
"""
LayerNorm is applied either before or after the self-attention/ffn
modules, similar to the original Transformer implementation.
"""
residual = x
if self.layer_norm_first:
x = self.self_attn_layer_norm(x)
if pos_bias is not None:
pos_bias = self.norm_k(pos_bias)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
attn_mask=self_attn_mask,
position_bias=pos_bias,
)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
else:
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
position_bias=pos_bias,
)
x = self.dropout1(x)
x = residual + x
x = self.self_attn_layer_norm(x)
residual = x
x = self.activation_fn(self.fc1(x))
x = self.dropout2(x)
x = self.fc2(x)
x = self.dropout3(x)
x = residual + x
x = self.final_layer_norm(x)
return x, attn
class FairseqDropout(nn.Module):
def __init__(self, p, module_name=None):
super().__init__()
self.p = p
self.module_name = module_name
self.apply_during_inference = False
def forward(self, x, inplace: bool = False):
if self.p > 0 and (self.training or self.apply_during_inference):
return F.dropout(x, p=self.p, training=True, inplace=inplace)
else:
return x
def make_generation_fast_(
self,
name: str,
retain_dropout: bool = False,
retain_dropout_modules: Optional[List[str]] = None,
**kwargs
):
if retain_dropout:
if retain_dropout_modules is not None and self.module_name is None:
logger.warning(
"Cannot enable dropout during inference for module {} "
"because module_name was not set".format(name)
)
elif (
retain_dropout_modules is None # if None, apply to all modules
or self.module_name in retain_dropout_modules
):
logger.info(
"Enabling dropout during inference for module: {}".format(name)
)
self.apply_during_inference = True
else:
logger.info("Disabling dropout for module: {}".format(name))
class LearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
Padding ids are ignored by either offsetting based on padding_idx
or by setting padding_idx to None and ensuring that the appropriate
position ids are passed to the forward function.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.onnx_trace = False
if self.padding_idx is not None:
self.max_positions = self.num_embeddings - self.padding_idx - 1
else:
self.max_positions = self.num_embeddings
def forward(
self,
input: Tensor,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
positions: Optional[Tensor] = None,
):
"""Input is expected to be of size [bsz x seqlen]."""
assert (positions is None) or (
self.padding_idx is None
), "If positions is pre-computed then padding_idx should not be set."
if positions is None:
if incremental_state is not None:
# positions is the same for every token when decoding a single step
# Without the int() cast, it doesn't work in some cases when exporting to ONNX
positions = torch.zeros(
(1, 1), device=input.device, dtype=input.dtype
).fill_(int(self.padding_idx + input.size(1)))
else:
positions = utils_make_positions(
input, self.padding_idx, onnx_trace=self.onnx_trace
)
positions = torch.clamp(positions, max=self.padding_idx + self.max_positions)
return F.embedding(
positions,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored.
"""
def __init__(self, embedding_dim, padding_idx, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx if padding_idx is not None else 0
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size, embedding_dim, padding_idx
)
self.onnx_trace = False
self.register_buffer("_float_tensor", torch.FloatTensor(1))
self.max_positions = int(1e5)
def prepare_for_onnx_export_(self):
self.onnx_trace = True
@staticmethod
def get_embedding(
num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
1
) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
num_embeddings, -1
)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(
self,
input,
incremental_state: Optional[Any] = None,
timestep: Optional[Tensor] = None,
positions: Optional[Any] = None,
):
"""Input is expected to be of size [bsz x seqlen]."""
bspair = torch.onnx.operators.shape_as_tensor(input)
bsz, seq_len = bspair[0], bspair[1]
max_pos = self.padding_idx + 1 + seq_len
if self.weights is None or max_pos > self.weights.size(0):
# recompute/expand embeddings if needed
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos, self.embedding_dim, self.padding_idx
)
self.weights = self.weights.to(self._float_tensor)
if incremental_state is not None:
# positions is the same for every token when decoding a single step
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
if self.onnx_trace:
return (
self.weights.index_select(index=self.padding_idx + pos, dim=0)
.unsqueeze(1)
.repeat(bsz, 1, 1)
)
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
positions = utils_make_positions(
input, self.padding_idx, onnx_trace=self.onnx_trace
)
if self.onnx_trace:
flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
embedding_shape = torch.cat(
(bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
)
embeddings = torch.onnx.operators.reshape_from_tensor_shape(
flat_embeddings, embedding_shape
)
return embeddings
return (
self.weights.index_select(0, positions.view(-1))
.view(bsz, seq_len, -1)
.detach()
)
try:
from apex.normalization import FusedLayerNorm as _FusedLayerNorm
has_fused_layernorm = True
class FusedLayerNorm(_FusedLayerNorm):
@torch.jit.unused
def forward(self, x):
if not x.is_cuda:
return super().forward(x)
else:
with torch.cuda.device(x.device):
return super().forward(x)
except ImportError:
has_fused_layernorm = False
class Fp32LayerNorm(nn.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.layer_norm(
input.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
class LayerDropModuleList(nn.ModuleList):
"""
A LayerDrop implementation based on :class:`torch.nn.ModuleList`.
We refresh the choice of which layers to drop every time we iterate
over the LayerDropModuleList instance. During evaluation we always
iterate over all layers.
Usage::
layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3])
for layer in layers: # this might iterate over layers 1 and 3
x = layer(x)
for layer in layers: # this might iterate over all layers
x = layer(x)
for layer in layers: # this might not iterate over any layers
x = layer(x)
Args:
p (float): probability of dropping out each layer
modules (iterable, optional): an iterable of modules to add
"""
def __init__(self, p, modules=None):
super().__init__(modules)
self.p = p
def __iter__(self):
dropout_probs = torch.empty(len(self)).uniform_()
for i, m in enumerate(super().__iter__()):
if not self.training or (dropout_probs[i] > self.p):
yield m
class RelativePositionalEncoding(torch.nn.Module):
def __init__(self, d_model, maxlen=1000, embed_v=False):
super(RelativePositionalEncoding, self).__init__()
self.d_model = d_model
self.maxlen = maxlen
self.pe_k = torch.nn.Embedding(2*maxlen, d_model)
if embed_v:
self.pe_v = torch.nn.Embedding(2*maxlen, d_model)
self.embed_v = embed_v
def forward(self, pos_seq, incremental_state=None):
pos_seq[pos_seq < -self.maxlen] = -self.maxlen
pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
pos_seq = pos_seq + self.maxlen
if incremental_state is not None:
pos_seq = pos_seq[-1:]
if self.embed_v:
return self.pe_k(pos_seq), self.pe_v(pos_seq)
else:
return self.pe_k(pos_seq), None
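# Illustrative sketch (hypothetical sizes): relative offsets outside
# [-maxlen, maxlen) are clamped before the embedding lookup, so any integer
# position sequence is safe to pass in.
#
#   >>> rpe = RelativePositionalEncoding(d_model=4, maxlen=10)
#   >>> pos = torch.arange(-15, 15)   # 30 offsets, some out of range
#   >>> pe_k, pe_v = rpe(pos)         # pe_v is None unless embed_v=True
#   >>> pe_k.shape
#   torch.Size([30, 4])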
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
has_relative_attention_bias=False,
scaling_for_att=1.0
):
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
)
self.has_relative_attention_bias = has_relative_attention_bias
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.scaling_for_att = scaling_for_att
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert not self.self_attention or self.qkv_same_dim, (
"Self-attention requires query, key and " "value to be of the same size"
)
self.k_proj = quant_noise(
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.v_proj = quant_noise(
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.q_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
self.out_proj = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
)
if add_bias_kv:
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.reset_parameters()
self.onnx_trace = False
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
nn.init.xavier_normal_(self.bias_v)
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
position_bias: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, _ = key.size()
if not torch.jit.is_scripting():
assert key_bsz == bsz
assert value is not None
                assert (src_len, bsz) == value.shape[:2]
if (
not self.onnx_trace
and not is_tpu # don't use PyTorch version on TPUs
and incremental_state is None
and not static_kv
# A workaround for quantization to work. Otherwise JIT compilation
# treats bias in linear module as method.
and not torch.jit.is_scripting()
and not self.has_relative_attention_bias
):
assert key is not None and value is not None
return F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
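        # standard 1/sqrt(head_dim) scaling, plus a pre-division by
        # scaling_for_att (re-multiplied after the position bias below)
        # to keep intermediate fp16 attention logits in range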
q *= self.scaling
q *= (1 / self.scaling_for_att)
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
        if position_bias is not None:  # first-order relative position bias
            # position_bias: (tgt_len, src_len, head_dim)
            # q: (bsz * num_heads, tgt_len, head_dim) -> (tgt_len, bsz * num_heads, head_dim)
            reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)
            # B: (tgt_len, bsz * num_heads, src_len)
            B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
            # reshape to (bsz * num_heads, tgt_len, src_len) to match attn_weights
            B = B.transpose(0, 1).view(bsz * self.num_heads, position_bias.size(0), position_bias.size(1))
attn_weights += B
attn_weights *= self.scaling_for_att
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
if self.onnx_trace:
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if self.scaling_for_att > 1.0:
attn_weights = attn_weights - attn_weights.detach().max(dim=-1, keepdim=True)[0]
if before_softmax:
return attn_weights, v
attn_weights_float = softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
if src_len > prev_key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - prev_key_padding_mask.size(1)),
device=prev_key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), filler.float()], dim=1
)
else:
new_key_padding_mask = prev_key_padding_mask.float()
elif key_padding_mask is not None:
if src_len > key_padding_mask.size(1):
filler = torch.zeros(
(batch_size, src_len - key_padding_mask.size(1)),
device=key_padding_mask.device,
)
new_key_padding_mask = torch.cat(
[filler.float(), key_padding_mask.float()], dim=1
)
else:
new_key_padding_mask = key_padding_mask.float()
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
@torch.jit.export
def reorder_incremental_state(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
new_order: Tensor,
):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer_k = input_buffer[k]
if input_buffer_k is not None:
if self.encoder_decoder_attention and input_buffer_k.size(
0
) == new_order.size(0):
break
input_buffer[k] = input_buffer_k.index_select(0, new_order)
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
return incremental_state
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[Tensor]] = {}
return empty_result
def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
buffer: Dict[str, Optional[Tensor]],
):
return self.set_incremental_state(incremental_state, "attn_state", buffer)
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + "." if name != "" else ""
items_to_add = {}
keys_to_remove = []
for k in state_dict.keys():
if k.endswith(prefix + "in_proj_weight"):
# in_proj_weight used to be q + k + v with same dimensions
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
keys_to_remove.append(k)
k_bias = prefix + "in_proj_bias"
if k_bias in state_dict.keys():
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
dim : 2 * dim
]
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
keys_to_remove.append(prefix + "in_proj_bias")
for k in keys_to_remove:
del state_dict[k]
for key, value in items_to_add.items():
state_dict[key] = value
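# Illustrative sketch (hypothetical sizes): a plain self-attention call in the
# (time, batch, channel) layout used throughout this module.
#
#   >>> mha = MultiheadAttention(embed_dim=16, num_heads=4, self_attention=True)
#   >>> x = torch.randn(5, 2, 16)     # (tgt_len, bsz, embed_dim)
#   >>> out, attn = mha(x, x, x)
#   >>> out.shape, attn.shape         # attn is averaged over the 4 heads
#   (torch.Size([5, 2, 16]), torch.Size([2, 5, 5]))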
class ConvFeatureExtractionModel(nn.Module):
def __init__(
self,
conv_layers: List[Tuple[int, int, int]],
dropout: float = 0.0,
mode: str = "default",
conv_bias: bool = False,
):
super().__init__()
assert mode in {"default", "layer_norm"}
def block(
n_in,
n_out,
k,
stride,
is_layer_norm=False,
is_group_norm=False,
conv_bias=False,
):
def make_conv():
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
nn.init.kaiming_normal_(conv.weight)
return conv
            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"
            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        # use the block's own n_out (identical to the loop
                        # variable `dim`) instead of relying on the closure
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
in_d = 1
self.conv_layers = nn.ModuleList()
for i, cl in enumerate(conv_layers):
assert len(cl) == 3, "invalid conv definition: " + str(cl)
(dim, k, stride) = cl
self.conv_layers.append(
block(
in_d,
dim,
k,
stride,
is_layer_norm=mode == "layer_norm",
is_group_norm=mode == "default" and i == 0,
conv_bias=conv_bias,
)
)
in_d = dim
def forward(self, x):
# BxT -> BxCxT
x = x.unsqueeze(1)
for conv in self.conv_layers:
x = conv(x)
return x
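# Illustrative sketch: the default wav2vec2-style stack downsamples raw 16 kHz
# waveform by a factor of 320 (strides 5*2*2*2*2*2*2), i.e. one frame per 20 ms.
#
#   >>> layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
#   >>> fe = ConvFeatureExtractionModel(conv_layers=layers)
#   >>> fe(torch.randn(1, 16000)).shape   # 1 second of audio -> 49 frames
#   torch.Size([1, 512, 49])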
class TransposeLast(nn.Module):
def __init__(self, deconstruct_idx=None):
super().__init__()
self.deconstruct_idx = deconstruct_idx
def forward(self, x):
if self.deconstruct_idx is not None:
x = x[self.deconstruct_idx]
return x.transpose(-2, -1)
class Fp32GroupNorm(nn.GroupNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.group_norm(
input.float(),
self.num_groups,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
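# Illustrative sketch: GradMultiply passes activations through unchanged but
# scales gradients in the backward pass (cf. feature_grad_mult in the configs).
#
#   >>> x = torch.ones(2, requires_grad=True)
#   >>> GradMultiply.apply(x, 0.1).sum().backward()
#   >>> x.grad
#   tensor([0.1000, 0.1000])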
class Rotate3D(nn.Module):
"""
(T, B, D) --> (B, D, T) --> (D, T, B) --> (T, B, D)
"""
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(1, 2, 0)
class SamePad(nn.Module):
def __init__(self, kernel_size, causal=False):
super().__init__()
if causal:
self.remove = kernel_size - 1
else:
self.remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
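# Illustrative sketch: with an even kernel and symmetric padding, the conv
# output is one frame too long, which SamePad trims back to the input length.
#
#   >>> conv = nn.Conv1d(8, 8, kernel_size=4, padding=2)
#   >>> same = SamePad(kernel_size=4)
#   >>> same(conv(torch.randn(1, 8, 100))).shape
#   torch.Size([1, 8, 100])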
from . import data, tasks, criterions, models
# @package _group_
defaults:
- model: null
hydra:
run:
dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
sweep:
dir: ${common_eval.results_path}
subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
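# With the decoding defaults below (beam=500, beamthreshold=25, lmweight=2,
# wordscore=-1, silweight=0) this resolves to
# ${common_eval.results_path}/beam500_th25_lmw2_wrd-1_sil0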
task:
_name: joint_sc2t_pretraining
data: ???
label_dir: ???
labels: ["ltr"]
store_labels: true
single_target: true
fine_tuning: true
normalize: ??? # must be consistent with pre-training
add_decoder_target: false
pad_audio: false
random_crop: true
hubert_tokenizer: "none"
sp_path: None
decoding:
type: fairseqlm
lexicon: ???
lmpath: ???
beamthreshold: 25
beam: 500
lmweight: 2
wordscore: -1
silweight: 0
unique_wer_file: true
common_eval:
results_path: ???
path: ???
post_process: letter
dataset:
max_tokens: 1100000
gen_subset: ???
# @package _group_
defaults:
- model: null
hydra:
run:
dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
sweep:
dir: ${common_eval.results_path}
subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
task:
_name: joint_sc2t_pretraining
data: ???
label_dir: ???
labels: ["ltr"]
store_labels: true
single_target: true
fine_tuning: true
normalize: ??? # must be consistent with pre-training
add_decoder_target: false
pad_audio: false
random_crop: true
hubert_tokenizer: "none"
sp_path: None
decoding:
type: kenlm
lexicon: ???
lmpath: ???
beamthreshold: 100
beam: 500
lmweight: 2
wordscore: -1
silweight: 0
unique_wer_file: true
common_eval:
results_path: ???
path: ???
post_process: letter
dataset:
max_tokens: 1100000
gen_subset: ???
# @package _group_
defaults:
- model: null
hydra:
run:
dir: ${common_eval.results_path}/viterbi
sweep:
dir: ${common_eval.results_path}
subdir: viterbi
task:
_name: joint_sc2t_pretraining
data: ???
label_dir: ???
labels: ["ltr"]
store_labels: true
single_target: true
fine_tuning: true
normalize: ??? # must be consistent with pre-training
add_decoder_target: false
pad_audio: false
random_crop: true
hubert_tokenizer: "none"
sp_path: None
decoding:
type: viterbi
unique_wer_file: true
common_eval:
results_path: ???
path: ???
post_process: letter
dataset:
batch_size: 1
gen_subset: ???
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tblog
seed: 1337
checkpoint:
save_interval: 1
keep_last_epochs: 1
keep_best_checkpoints: -1
best_checkpoint_metric: wer
restore_file: checkpoint_last.pt
distributed_training:
ddp_backend: legacy_ddp
find_unused_parameters: true
distributed_world_size: 1
distributed_port: -1
nprocs_per_node: 8
task:
_name: joint_sc2t_pretraining
data: ???
fine_tuning: true
label_dir: ???
normalize: false # must be consistent with pre-training
labels: ["ltr"]
store_labels: true
single_target: true
add_decoder_target: false
pad_audio: false
random_crop: true
hubert_tokenizer: "none"
sp_path: None
dataset:
num_workers: 0
max_tokens: 1600000
skip_invalid_size_inputs_valid_test: true
train_subset: train_100
valid_subset: dev_other
required_batch_size_multiple: 1
criterion:
_name: ctc
zero_infinity: true
optimization:
max_update: 30000
lr: [0.00001]
sentence_avg: true
update_freq: [1]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-08
weight_decay: 0.0
lr_scheduler:
_name: tri_stage
phase_ratio: [0.1, 0.4, 0.5]
final_lr_scale: 0.05
model:
_name: speechlm_ctc
w2v_path: ???
apply_mask: true
mask_prob: 0.65
mask_channel_prob: 0.5
mask_channel_length: 64
layerdrop: 0.1
activation_dropout: 0.1
feature_grad_mult: 0.0
freeze_finetune_updates: 0
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
- model.w2v_path
- dataset.train_subset
- dataset.valid_subset
- criterion.wer_kenlm_model
- criterion.wer_lexicon
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tblog
checkpoint:
save_interval: 1
keep_last_epochs: 5
keep_best_checkpoints: 5
best_checkpoint_metric: wer
restore_file: checkpoint_last.pt
distributed_training:
ddp_backend: legacy_ddp
find_unused_parameters: true
distributed_world_size: 32
distributed_port: -1
nprocs_per_node: 8
task:
_name: joint_sc2t_pretraining
data: ???
fine_tuning: true
label_dir: ???
normalize: true # must be consistent with pre-training
labels: ["ltr"]
store_labels: true
single_target: true
add_decoder_target: false
pad_audio: false
random_crop: true
hubert_tokenizer: "none"
sp_path: None
dataset:
num_workers: 0
max_tokens: 900000
skip_invalid_size_inputs_valid_test: true
train_subset: train_960
valid_subset: dev_other
required_batch_size_multiple: 1
criterion:
_name: ctc
zero_infinity: true
optimization:
max_update: 200000
lr: [0.00001]
sentence_avg: true
update_freq: [1]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-08
weight_decay: 0.0
lr_scheduler:
_name: tri_stage
phase_ratio: [0.1, 0.4, 0.5]
final_lr_scale: 0.05
model:
_name: speechlm_ctc
w2v_path: ???
apply_mask: true
mask_prob: 0.5
mask_channel_prob: 0.25
mask_channel_length: 64
layerdrop: 0.0
activation_dropout: 0.1
feature_grad_mult: 0.0
freeze_finetune_updates: 0
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
- model.w2v_path
- dataset.train_subset
- dataset.valid_subset
- criterion.wer_kenlm_model
- criterion.wer_lexicon
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
seed: 1337
tensorboard_logdir: tblog
checkpoint:
save_dir: ???
save_interval: 4
keep_last_epochs: 4
save_interval_updates: 50000
keep_interval_updates: -1
keep_interval_updates_pattern: 50000
# no_epoch_checkpoints: true
distributed_training:
ddp_backend: legacy_ddp
distributed_backend: 'nccl'
distributed_port: -1
distributed_world_size: 32
nprocs_per_node: 8
find_unused_parameters: true
task:
_name: joint_sc2t_pretraining
data: ???
label_dir: ???
labels: ???
label_rate: ${model.label_rate}
store_labels: true
sample_rate: 16000
max_sample_size: 250000
min_sample_size: 32000
pad_audio: false
random_crop: true
normalize: false # must be consistent with extractor
add_decoder_target: false
text_cfg:
seed: ${common.seed}
text_data: ???
data_config: config.yaml
sample_break_mode: eos
tokens_per_sample: 1024
shorten_method: "random_crop"
text_maxtokens_ratio: 1.0
dataset:
num_workers: 6
max_tokens: 1400000
skip_invalid_size_inputs_valid_test: true
validate_interval: ${checkpoint.save_interval}
validate_interval_updates: ${checkpoint.save_interval_updates}
required_batch_size_multiple: 1
criterion:
_name: speechlm_criterion
pred_masked_weight: 1.0
pred_nomask_weight: 0.0
loss_weights: [10,]
text_ctc_weight: 0.1
text_mum_weight: 0.0
optimization:
max_update: 400000
lr: [0.0005]
clip_norm: 10.0
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: polynomial_decay
warmup_updates: 32000
model:
_name: speechlm
label_rate: ???
skip_masked: false
skip_nomask: false
mask_prob: 0.80
extractor_mode: default
conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
final_dim: 256
activation_fn: "gelu"
encoder_layers: 6
encoder_attention_heads: 8
encoder_layerdrop: 0.1
dropout_input: 0.1
dropout_features: 0.1
dropout: 0.1
attention_dropout: 0.1
feature_grad_mult: 0.1
untie_final_proj: true
activation_dropout: 0.0
use_rel_pos_enc: true
add_unit_encoder: true
add_text_ctc: true
mask_u2t: true
compute_mum: false
mix_with_unit: true
text_transformer:
activation_fn: ${model.activation_fn}
dropout: ${model.dropout}
attention_dropout: ${model.attention_dropout}
activation_dropout: ${model.activation_dropout}
max_source_positions: 3000
no_scale_embedding: true
layernorm_embedding: true
no_token_positional_embeddings: false
encoder:
embed_dim: 768
ffn_embed_dim: 3072
layers: 6
attention_heads: 8
normalize_before: false
learned_pos: true
layerdrop: ${model.encoder_layerdrop}
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
seed: 1234
tensorboard_logdir: tblog
checkpoint:
save_dir: ???
save_interval: 1
keep_last_epochs: 4
save_interval_updates: 10000
keep_interval_updates: 40
keep_interval_updates_pattern: 10000
# no_epoch_checkpoints: true
distributed_training:
ddp_backend: legacy_ddp
distributed_backend: 'nccl'
distributed_port: -1
distributed_world_size: 32
nprocs_per_node: 8
find_unused_parameters: true
task:
_name: joint_sc2t_pretraining
data: ???
label_dir: ???
labels: ???
label_rate: ${model.label_rate}
store_labels: true
sample_rate: 16000
max_sample_size: 250000
min_sample_size: 32000
pad_audio: false
random_crop: true
normalize: true # must be consistent with extractor
add_decoder_target: false
text_cfg:
seed: ${common.seed}
text_data: ???
data_config: config.yaml
sample_break_mode: eos
tokens_per_sample: 1024
shorten_method: "random_crop"
text_maxtokens_ratio: 1.0
dataset:
num_workers: 1
max_tokens: 900000
skip_invalid_size_inputs_valid_test: true
validate_interval: ${checkpoint.save_interval}
validate_interval_updates: ${checkpoint.save_interval_updates}
required_batch_size_multiple: 2
criterion:
_name: speechlm_criterion
pred_masked_weight: 1.0
pred_nomask_weight: 0.0
loss_weights: [10,]
text_ctc_weight: 0.1
text_mum_weight: 0.0
optimization:
max_update: 400000
lr: [0.001]
clip_norm: 1.0
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: polynomial_decay
warmup_updates: 32000
model:
_name: speechlm
label_rate: ???
activation_fn: "gelu"
encoder_layers: 12
encoder_embed_dim: 1024
encoder_ffn_embed_dim: 4096
encoder_attention_heads: 16
final_dim: 256
skip_masked: false
skip_nomask: false
mask_prob: 0.80
extractor_mode: layer_norm
conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
encoder_layerdrop: 0.0
dropout_input: 0.0
dropout_features: 0.0
dropout: 0.0
attention_dropout: 0.0
layer_norm_first: true
feature_grad_mult: 1.0
untie_final_proj: true
activation_dropout: 0.0
use_rel_pos_enc: true
add_unit_encoder: true
add_text_ctc: true
mask_u2t: true
compute_mum: false
mix_with_unit: true
scaling_for_att: 32
text_transformer:
activation_fn: ${model.activation_fn}
dropout: ${model.dropout}
attention_dropout: ${model.attention_dropout}
activation_dropout: ${model.activation_dropout}
max_source_positions: 3000
no_scale_embedding: true
layernorm_embedding: true
no_token_positional_embeddings: false
encoder:
embed_dim: 1024
ffn_embed_dim: 4096
layers: 12
attention_heads: 16
normalize_before: ${model.layer_norm_first}
learned_pos: true
layerdrop: ${model.encoder_layerdrop}
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
criterion_name = file[: file.find(".py")]
importlib.import_module(
"speechlm.criterions." + criterion_name
)
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
from typing import List, Dict, Any
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import lengths_to_mask
from fairseq.models.fairseq_model import FairseqEncoderModel
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
    if reduce:
        # count non-padding targets; with no ignore_index every target counts
        ntokens = (~pad_mask).sum() if ignore_index is not None else target.numel()
nll_loss = nll_loss.sum() / ntokens
smooth_loss = smooth_loss.sum() / ntokens
eps_i = epsilon / (lprobs.size(-1) - 1)
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
return loss, nll_loss
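# Minimal sketch (toy values): the smoothed loss on a 4-example, 3-class batch.
#
#   >>> lprobs = torch.log_softmax(torch.randn(4, 3), dim=-1)
#   >>> target = torch.tensor([0, 2, 1, 0])
#   >>> loss, nll_loss = label_smoothed_nll_loss(lprobs, target, epsilon=0.1)
#   >>> loss.shape, nll_loss.shape    # scalars when reduce=True
#   (torch.Size([]), torch.Size([]))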
@dataclass
class FastText2UnitCriterionConfig(FairseqDataclass):
label_smoothing: float = field(
default=0.0,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
dur_loss_weight: float = field(
default=1.0,
metadata={"help": "scale of duration loss"},
)
report_accuracy: bool = field(
default=True,
metadata={"help": "report decoder accuracy metric"},
)
@register_criterion("fasttext2unit_criterion", dataclass=FastText2UnitCriterionConfig)
class FastText2UnitLoss(FairseqCriterion):
def __init__(self,
task,
label_smoothing=0,
dur_loss_weight=1.0,
report_accuracy=False,
):
super().__init__(task)
self.eps = label_smoothing
self.dur_loss_weight = dur_loss_weight
self.pad_idx = task.tgt_dict.pad()
self.report_accuracy = report_accuracy
def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
src_tokens = sample["net_input"]["src_tokens"]
src_lens = sample["net_input"]["src_lengths"]
tgt_lens = sample["target_lengths"]
_feat_out, _feat_out_post, out_lens, log_dur_out, pitch_out, energy_out = model(
src_tokens=src_tokens,
src_lengths=src_lens,
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None,
target_lengths=tgt_lens,
speaker=sample["speaker"],
durations=sample["durations"],
pitches=sample["pitches"],
energies=sample["energies"],
)
src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
tgt_mask = lengths_to_mask(sample["target_lengths"])
lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
target = sample["target"].long()
        ce_loss, nll_loss = label_smoothed_nll_loss(lprobs, target, self.eps, self.pad_idx, reduce=True)
pitches, energies = sample["pitches"], sample["energies"]
if pitches is not None:
pitch_out, pitches = pitch_out[src_mask], pitches[src_mask]
pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
else:
pitch_loss = 0
if energies is not None:
energy_out, energies = energy_out[src_mask], energies[src_mask]
energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)
else:
energy_loss = 0
log_dur_out = log_dur_out[src_mask]
dur = sample["durations"].float()
dur = dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur
log_dur = torch.log(dur + 1)[src_mask]
dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)
dur_loss = self.dur_loss_weight * dur_loss
loss = ce_loss + dur_loss + pitch_loss + energy_loss
sample_size = sample["nsentences"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"ce_loss": utils.item(ce_loss.data),
"dur_loss": utils.item(dur_loss.data),
"pitch_loss": utils.item(pitch_loss),
"energy_loss": utils.item(energy_loss),
}
if self.report_accuracy:
n_correct = lprobs.argmax(-1).masked_select(tgt_mask).eq(target.masked_select(tgt_mask)).sum()
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = tgt_mask.sum()
return loss, 1, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
ns = [log.get("sample_size", 0) for log in logging_outputs]
ntot = sum(ns)
ws = [n / (ntot + 1e-8) for n in ns]
for key in [
"loss",
"ce_loss",
"dur_loss",
"pitch_loss",
"energy_loss",
]:
vals = [log.get(key, 0) for log in logging_outputs]
val = sum(val * w for val, w in zip(vals, ws))
metrics.log_scalar(key, val, ntot, round=3)
metrics.log_scalar("sample_size", ntot, len(logging_outputs))
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
metrics.log_scalar("total", total)
n_correct = utils.item(
sum(log.get("n_correct", 0) for log in logging_outputs)
)
metrics.log_scalar("n_correct", n_correct)
metrics.log_derived(
"accuracy",
lambda meters: round(
meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
)
if meters["total"].sum > 0
else float("nan"),
)
# inference metrics
if "targ_frames" not in logging_outputs[0]:
return
n = sum(log.get("targ_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
return False
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import logging
import math
import re
from dataclasses import dataclass, field
from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
from fairseq.dataclass import FairseqDataclass
logger = logging.getLogger(__name__)
@dataclass
class HSTCriterionConfig(FairseqDataclass):
pred_masked_weight: float = field(
default=1.0,
metadata={"help": "weight for predictive loss for masked frames"},
)
pred_nomask_weight: float = field(
default=0.0,
metadata={"help": "weight for predictive loss for unmasked frames"},
)
loss_weights: Optional[List[float]] = field(
default=None,
metadata={"help": "weights for additional loss terms (not first one)"},
)
log_keys: List[str] = field(
default_factory=lambda: [],
metadata={"help": "output keys to log"},
)
text_ctc_weight: float = field(
default=0.1,
metadata={"help": "weights for text CTC Loss, loss will be (hubert_loss + dec_weight * CE_Loss + text_weight * (CE_Loss + CTC_loss))"},
)
text_mum_weight: float = field(
default=0.0,
metadata={"help": "masked unit modeling weight from the text end"},
)
report_accuracy: bool = field(
default=True,
metadata={"help": "report decoder accuracy metric"},
)
ignore_prefix_size: int = field(
default=0,
metadata={"help": "Ignore first N tokens"},
)
no_ctc_blank: bool = field(
default=False,
metadata={"help": "mask out the blank of ctc, only when dec_loss_type=ctc"},
)
@register_criterion("speechlm_criterion", dataclass=HSTCriterionConfig)
class SpeechLMCriterion(FairseqCriterion):
def __init__(
self,
task,
pred_masked_weight,
pred_nomask_weight,
loss_weights=None,
log_keys=None,
text_ctc_weight=0.1,
text_mum_weight=0,
report_accuracy=False,
ignore_prefix_size=0,
no_ctc_blank=False,
):
super().__init__(task)
self.pred_masked_weight = pred_masked_weight
self.pred_nomask_weight = pred_nomask_weight
self.loss_weights = loss_weights
self.log_keys = [] if log_keys is None else log_keys
self.text_ctc_weight = text_ctc_weight
self.text_mum_weight = text_mum_weight
self.report_accuracy = report_accuracy
self.ignore_prefix_size = ignore_prefix_size
self.no_ctc_blank = no_ctc_blank
self.padding_idx = task.dictionaries[0].pad()
self.eos_idx = task.dictionaries[0].eos()
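        # the bos symbol of the unit dictionary doubles as the CTC blank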
self.blank_idx = task.dictionaries[0].bos()
def compute_hubert_loss(self, model, net_output, reduction, suffix=''):
loss = 0
sample_size = []
logging_output = {}
loss_m_list = []
logp_m_list = model.get_logits(net_output, True)
targ_m_list = model.get_targets(net_output, True)
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"loss_m_{i}{suffix}"] = loss_m.detach().item()
if self.pred_masked_weight > 0:
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size.append(targ_m_list[0].numel())
loss_u_list = []
logp_u_list = model.get_logits(net_output, False)
targ_u_list = model.get_targets(net_output, False)
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"loss_u_{i}{suffix}"] = loss_u.detach().item()
if self.pred_nomask_weight > 0:
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size.append(targ_u_list[0].numel())
sample_size = np.mean(sample_size)
def compute_correct(logits, targets):
if logits.numel() == 0:
return 0, 0
else:
                assert logits.dim() > 1, logits.shape
                is_max = logits.argmax(-1) == targets
                is_min = logits.argmin(-1) == targets
                # a hit only counts when the argmax is unambiguous: rows whose
                # logits are all equal have argmax == argmin and are excluded
                both = is_max & is_min
                corr = is_max.long().sum().item() - both.long().sum().item()
                count = is_max.numel()
                return corr, count
with torch.no_grad():
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
corr_m, count_m = compute_correct(logp_m, targ_m)
logging_output[f"correct_m_{i}{suffix}"] = corr_m
logging_output[f"count_m_{i}{suffix}"] = count_m
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
corr_u, count_u = compute_correct(logp_u, targ_u)
logging_output[f"correct_u_{i}{suffix}"] = corr_u
logging_output[f"count_u_{i}{suffix}"] = count_u
return loss, sample_size, logging_output
def forward(self, model, sample, reduce=True, log_pred=False):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
reduction = "sum" if reduce else "none"
if "net_input" in sample:
text_sample = None
else:
text_sample = sample.get("text_paired")
sample = sample.get("speech")
### 1. L_UMLM: do hubert forward and loss computation
sample["modality"] = "speech"
net_output = model(target_list=sample["target_list"], **sample["net_input"])
loss, sample_size, logging_output = self.compute_hubert_loss(
model,
net_output,
reduction,
)
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses, names = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
for lk in self.log_keys:
if lk in net_output:
                logging_output[lk] = float(net_output[lk])
### 2. do text forward and loss computation
if text_sample is not None:
text_sample["modality"] = "text"
            ## 2.1 reload "target_list": by default target_list = [src_tokens];
            ## with the "unit-phone-char" structure it becomes [ref_tokens]
text_sample["net_input"]["target_list"] = [
text_sample.get("ref_tokens", text_sample["net_input"]["src_tokens"].clone()),
]
text_net_output = model(**text_sample["net_input"])
### 2.2 L_UMLM (text-end, not applied by default)
if self.text_mum_weight > 0:
loss_u2t, sample_size_u2t, logging_output_u2t = self.compute_hubert_loss(
model,
text_net_output,
reduction,
suffix="_u2t",
)
loss += self.text_mum_weight * loss_u2t * sample_size / sample_size_u2t
logging_output.update(logging_output_u2t)
### 2.3 L_UCTC
text_sample_size = text_sample["ntokens"]
if self.text_ctc_weight > 0:
text_ctc_loss = self.compute_ctc_loss(model, text_net_output, text_sample["target"], reduction=reduction)
loss += self.text_ctc_weight * text_ctc_loss * sample_size / text_sample_size
logging_output["text_ctc_loss"] = utils.item(text_ctc_loss)
logging_output["text_sample_size"] = text_sample_size
logging_output = {
"loss": utils.item(loss) if reduce else loss,
"ntokens": sample_size,
"nsentences": sample["id"].numel() + (text_sample["id"].numel() if text_sample is not None else 0),
"sample_size": sample_size,
**logging_output,
}
return loss, sample_size, logging_output
def compute_ctc_loss(self, model, net_output, target, reduction):
logits = net_output["encoder_out_ctc"][0] # (T, B, C) from the code-encoder
if self.no_ctc_blank:
            ## suppress <blank> by setting its logit to a large negative value
logits = logits.float()
logits[:, :, self.blank_idx] = -1000000.0
lprobs = F.log_softmax(logits.float(), dim=-1)
encoder_padding_mask = net_output["encoder_padding_mask"][0]
non_padding_mask = ~encoder_padding_mask
input_lengths = non_padding_mask.long().sum(-1)
pad_mask = (target != self.padding_idx) & (target != self.eos_idx)
targets_flat = target.masked_select(pad_mask)
target_lengths = pad_mask.sum(-1)
with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss(
lprobs,
targets_flat,
input_lengths,
target_lengths,
blank=self.blank_idx,
reduction=reduction,
zero_infinity=True,
)
return loss
def compute_ce_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
return loss, nll_loss
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
return n_correct, total
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
if sample["modality"] == "speech":
target = sample["decoder_target"]
if self.ignore_prefix_size > 0:
if getattr(lprobs, "batch_first", False):
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
else:
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
target = target[self.ignore_prefix_size :, :].contiguous()
else:
target = sample["target"]
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
counts = {}
for lk in logging_outputs[0].keys():
if lk.startswith("count_"):
val = sum(log.get(lk, 0) for log in logging_outputs)
metrics.log_scalar(lk, val)
counts[lk] = val
for lk in logging_outputs[0].keys():
if lk.startswith("loss_"):
val = sum(log.get(lk, 0) for log in logging_outputs)
metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
elif lk.startswith("correct_"):
val = sum(log.get(lk, 0) for log in logging_outputs)
metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
if "text_sample_size" in logging_outputs[0]:
text_sample_size = sum(log.get("text_sample_size", 0) for log in logging_outputs)
for lk in logging_outputs[0].keys():
if lk.startswith("text_") and lk.endswith("_loss"):
val = sum(log.get(lk, 0) for log in logging_outputs)
metrics.log_scalar(lk, val / text_sample_size / math.log(2), round=3)
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError()
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
"""
return False
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import bisect
import numpy as np
from torch.utils.data.dataloader import default_collate
from fairseq.data import FairseqDataset
class ConcatDataset(FairseqDataset):
@staticmethod
def cumsum(sequence, sample_ratios):
r, s = [], 0
for e, ratio in zip(sequence, sample_ratios):
curr_len = int(ratio * len(e))
r.append(curr_len + s)
s += curr_len
return r
def __init__(self, datasets, sample_ratios=1):
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, "datasets should not be an empty iterable"
self.datasets = list(datasets)
if isinstance(sample_ratios, int):
sample_ratios = [sample_ratios] * len(self.datasets)
self.sample_ratios = sample_ratios
self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
self.real_sizes = [len(d) for d in self.datasets]
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
return self.datasets[dataset_idx][sample_idx]
def _get_dataset_and_sample_index(self, idx: int):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
sample_idx = sample_idx % self.real_sizes[dataset_idx]
return dataset_idx, sample_idx
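    # Illustrative: with sample_ratios=[1, 2], the second dataset contributes
    # 2 * len(d2) virtual examples; indices past its real length wrap around
    # via the modulo above, so each epoch sees it roughly twice.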
def collater(self, samples, **extra_args):
        # For now, only supports datasets with the same underlying collater implementation
if hasattr(self.datasets[0], "collater"):
return self.datasets[0].collater(samples, **extra_args)
else:
return default_collate(samples, **extra_args)
def size(self, idx: int):
"""
Return an example's size as a float or tuple.
"""
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
return self.datasets[dataset_idx].size(sample_idx)
def num_tokens(self, index: int):
return np.max(self.size(index))
def attr(self, attr: str, index: int):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
return getattr(self.datasets[dataset_idx], attr, None)
@property
def sizes(self):
_dataset_sizes = []
for ds, sr in zip(self.datasets, self.sample_ratios):
if isinstance(ds.sizes, np.ndarray):
_dataset_sizes.append(np.tile(ds.sizes, sr))
else:
# Only support underlying dataset with single size array.
assert isinstance(ds.sizes, list)
_dataset_sizes.append(np.tile(ds.sizes[0], sr))
return np.concatenate(_dataset_sizes)
@property
def supports_prefetch(self):
return all(d.supports_prefetch for d in self.datasets)
def ordered_indices(self):
"""
        Return indices sorted by length so that less padding is needed.
"""
if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
# special handling for concatenating lang_pair_datasets
if getattr(self.datasets[0], "shuffle", False):
indices = np.random.permutation(len(self)).astype(np.int64)
else:
indices = np.arange(len(self), dtype=np.int64)
sizes = self.sizes
tgt_sizes = (
sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
)
src_sizes = (
sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
)
# sort by target length, then source length
if tgt_sizes is not None:
indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
return indices[np.argsort(src_sizes[indices], kind="mergesort")]
else:
return np.argsort(self.sizes)
def prefetch(self, indices):
frm = 0
for to, ds in zip(self.cumulative_sizes, self.datasets):
real_size = len(ds)
if getattr(ds, "supports_prefetch", False):
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
frm = to
@property
def can_reuse_epoch_itr_across_epochs(self):
return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.datasets:
if hasattr(ds, "set_epoch"):
ds.set_epoch(epoch)
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import itertools
import logging
import io
import os
import sys
import time
from pathlib import Path
from typing import Any, List, Optional, Union, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import data_utils, Dictionary
from fairseq.data.fairseq_dataset import FairseqDataset
from fairseq.data.audio.audio_utils import (
read_from_stored_zip,
is_sf_audio_data,
)
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
logger = logging.getLogger(__name__)
def parse_path(path: str) -> Tuple[str, List[int]]:
"""Parse data path which is either a path to
1. a .npy/.wav/.flac/.ogg file
2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
Args:
path (str): the data path to parse
Returns:
file_path (str): the file path
slice_ptr (list of int): empty in case 1;
byte offset and length for the slice in case 2
"""
if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
_path, slice_ptr = path, []
else:
_path, *slice_ptr = path.split(":")
if not Path(_path).is_file():
raise FileNotFoundError(f"File not found: {_path}")
assert len(slice_ptr) in {0, 1, 2}, f"Invalid path: {path}"
slice_ptr = [int(i) for i in slice_ptr]
return _path, slice_ptr
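# Illustrative sketch (assuming these files exist on disk):
#
#   >>> parse_path("utt1.flac")
#   ('utt1.flac', [])
#   >>> parse_path("shard0.zip:1024:2048")   # byte offset and length into the zip
#   ('shard0.zip', [1024, 2048])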
def load_audio(manifest_path, max_keep, min_keep, retry_times=5):
n_long, n_short = 0, 0
names, inds, sizes, chunk_names, chunk_indices = [], [], [], [], []
for i in range(retry_times):
with open(manifest_path) as f:
root = f.readline().strip()
for ind, line in enumerate(f):
items = line.strip().split("\t")
assert len(items) == 2, line
sz = int(items[1])
if min_keep is not None and sz < min_keep:
n_short += 1
elif max_keep is not None and sz > max_keep:
n_long += 1
else:
fname = items[0].split(":")
if len(fname) > 2:
if len(chunk_names) == 0 or fname[0] != chunk_names[-1]:
chunk_names.append(fname[0])
chunk_indices.append(len(names))
names.append(items[0])
inds.append(ind)
sizes.append(sz)
if len(names) == 0:
logger.warn(f"Fail to load manifest for the {i} time")
time.sleep(1)
continue
else:
break
tot = ind + 1
logger.info(
(
f"max_keep={max_keep}, min_keep={min_keep}, "
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
)
)
return root, names, inds, tot, sizes, chunk_names, chunk_indices
def load_label(label_path, inds, tot, retry_times=5):
for i in range(retry_times):
with open(label_path) as f:
labels = [line.rstrip() for line in f]
if len(labels) == 0:
logger.warn(f"Fail to load label for the {i} time")
time.sleep(1)
continue
else:
break
assert (
len(labels) == tot
), f"number of labels does not match ({len(labels)} != {tot})"
labels = [labels[i] for i in inds]
return labels
def load_label_offset(label_path, inds, tot, retry_times=5):
for i in range(retry_times):
with open(label_path) as f:
code_lengths = [len(line.encode("utf-8")) for line in f]
if len(code_lengths) == 0:
logger.warn(f"Fail to load label for the {i} time")
time.sleep(1)
continue
else:
break
assert (
len(code_lengths) == tot
), f"number of labels does not match ({len(code_lengths)} != {tot})"
offsets = list(itertools.accumulate([0] + code_lengths))
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
return offsets
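# e.g. if the label file's lines occupy 5, 7 and 3 bytes, the cumulative offsets
# are [0, 5, 12, 15], so label i can later be fetched by seeking to
# offsets[i][0] and reading offsets[i][1] - offsets[i][0] bytes.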
def verify_label_lengths(
audio_sizes,
audio_rate,
label_path,
label_rate,
inds,
tot,
tol=0.1, # tolerance in seconds
):
if label_rate < 0:
logger.info(f"{label_path} is sequence label. skipped")
return
with open(label_path) as f:
lengths = [len(line.rstrip().split()) for line in f]
assert len(lengths) == tot
lengths = [lengths[i] for i in inds]
num_invalid = 0
for i, ind in enumerate(inds):
dur_from_audio = audio_sizes[i] / audio_rate
dur_from_label = lengths[i] / label_rate
if abs(dur_from_audio - dur_from_label) > tol:
logger.warning(
(
f"audio and label duration differ too much "
f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
f"in line {ind+1} of {label_path}. Check if `label_rate` "
f"is correctly set (currently {label_rate}). "
f"num. of samples = {audio_sizes[i]}; "
f"label length = {lengths[i]}"
)
)
num_invalid += 1
if num_invalid > 0:
logger.warning(
f"total {num_invalid} (audio, label) pairs with mismatched lengths"
)
class HubertDataset(FairseqDataset):
def __init__(
self,
manifest_path: str,
sample_rate: float,
label_paths: List[str],
label_rates: Union[List[float], float], # -1 for sequence labels
pad_list: List[int],  # pad indices, passed to collate_tokens as pad_idx
eos_list: List[int],
label_processors: Optional[List[Any]] = None,
max_keep_sample_size: Optional[int] = None,
min_keep_sample_size: Optional[int] = None,
max_sample_size: Optional[int] = None,
shuffle: bool = True,
pad_audio: bool = False,
normalize: bool = False,
store_labels: bool = True,
random_crop: bool = False,
single_target: bool = False,
tgt_dict: Optional[Dictionary] = None,
add_decoder_target: bool = False,
fine_tuning: bool = False,
tgt_lang_idx: Optional[int] = None,
tokenizer=None,
mbart_style_lang_id: bool = False,
retry_times: int = 5,
reduce_label_for_dec: bool = True,
):
self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.chunk_names, self.chunk_indices = load_audio(
manifest_path, max_keep_sample_size, min_keep_sample_size, retry_times
)
self.sample_rate = sample_rate
self.shuffle = shuffle
self.random_crop = random_crop
self.tgt_dict = tgt_dict
self.add_decoder_target = add_decoder_target
self.fine_tuning = fine_tuning
self.num_labels = len(label_paths)
self.pad_list = pad_list
self.eos_list = eos_list
self.label_processors = label_processors
self.single_target = single_target
self.epoch = 0
self.label_rates = (
[label_rates for _ in range(len(label_paths))]
if isinstance(label_rates, (int, float))  # broadcast a scalar rate to every label stream
else label_rates
)
self.store_labels = store_labels
if store_labels:
self.label_list = [load_label(p, inds, tot, retry_times) for p in label_paths]
else:
self.label_paths = label_paths
self.label_offsets_list = [
load_label_offset(p, inds, tot, retry_times) for p in label_paths
]
assert label_processors is None or len(label_processors) == self.num_labels
for label_path, label_rate in zip(label_paths, self.label_rates):
verify_label_lengths(
self.wav_sizes, sample_rate, label_path, label_rate, inds, tot
)
self.max_sample_size = (
max_sample_size if max_sample_size is not None else sys.maxsize
)
self.pad_audio = pad_audio
self.normalize = normalize
self.tgt_lang_idx = tgt_lang_idx
self.tokenizer = tokenizer
self.mbart_style_lang_id = mbart_style_lang_id
self.retry_times = retry_times
self.reduce_label_for_dec = reduce_label_for_dec
logger.info(
f"pad_audio={pad_audio}, random_crop={random_crop}, tgt_lang_idx={self.tgt_lang_idx}, reduce_label_for_dec={reduce_label_for_dec}, "
f"mbart_style_lang_id={mbart_style_lang_id}, normalize={normalize}, max_sample_size={self.max_sample_size}"
)
def set_epoch(self, epoch):
self.epoch = epoch
def batch_by_size(self, indices, max_tokens=None, max_sentences=None, required_batch_size_multiple=1):
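"""Delegate to FairseqDataset.batch_by_size. When `indices` is a list of
arrays (the chunked ordering produced by ordered_indices), each array is
batched separately and a list of batch lists is returned."""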
self.max_tokens = max_tokens
self.max_sentences = max_sentences
self.required_batch_size_multiple = required_batch_size_multiple
if isinstance(indices[0], np.ndarray):
batch_list = []
for indice in indices:
batch = super().batch_by_size(indice, max_tokens, max_sentences, required_batch_size_multiple)
batch_list.append(batch)
return batch_list
else:
return super().batch_by_size(indices, max_tokens, max_sentences, required_batch_size_multiple)
def shuffle_batches(self, batches, seed):
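"""Deterministically shuffle batches given a seed. For chunked batch
lists, shuffle the group order, then the batches inside each group, and
flatten the result."""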
if isinstance(batches[0], list):
new_batches = []
with data_utils.numpy_seed(seed):
np.random.shuffle(batches)
for batch in batches:
np.random.shuffle(batch)
new_batches.extend(batch)
return new_batches
else:
with data_utils.numpy_seed(seed):
np.random.shuffle(batches)
return batches
def get_audio(self, index):
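"""Load one utterance. A path with a single slice integer is read as a
kaldi matrix (kaldiio), one with two slice integers as an audio blob
inside a zip, and anything else as a regular audio file; transient read
failures are retried up to retry_times."""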
import soundfile as sf
wav_path = os.path.join(self.audio_root, self.audio_names[index])
_path, slice_ptr = parse_path(wav_path)
if len(slice_ptr) == 1:
import kaldiio
feat = kaldiio.load_mat(wav_path)
feat = torch.from_numpy(feat).float()
if self.normalize:
with torch.no_grad():
feat = F.layer_norm(feat, feat.shape[-1:])  # normalize over the feature dim; pass a shape rather than a bare int
return feat
else:
if len(slice_ptr) == 2:
byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
assert is_sf_audio_data(byte_data)
wav_path = io.BytesIO(byte_data)
for i in range(self.retry_times):
if i < self.retry_times - 1:
try:
wav, cur_sample_rate = sf.read(wav_path)
break
except Exception as e:
logger.warn(f"Fail to load wav for the {i} time")
logger.warn(e)
time.sleep(1)
continue
else:
wav, cur_sample_rate = sf.read(wav_path)
wav = torch.from_numpy(wav).float()
wav = self.postprocess(wav, cur_sample_rate)
return wav
def get_label(self, index, label_idx):
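"""Return the label for `index`: from memory when store_labels is set,
otherwise by seeking to the precomputed byte offsets in the label file.
Applies the fine-tuning tokenizer and any label processor if configured."""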
if self.store_labels:
label = self.label_list[label_idx][index]
else:
with open(self.label_paths[label_idx]) as f:
offset_s, offset_e = self.label_offsets_list[label_idx][index]
f.seek(offset_s)
label = f.read(offset_e - offset_s)
if self.tokenizer is not None and self.fine_tuning:
label = self.tokenizer.encode(label)
if self.label_processors is not None:
label = self.label_processors[label_idx](label)
return label
def get_labels(self, index):
return [self.get_label(index, i) for i in range(self.num_labels)]
def __getitem__(self, index):
wav = self.get_audio(index)
labels = self.get_labels(index)
return {"id": index, "source": wav, "label_list": labels}
def __len__(self):
return len(self.wav_sizes)
def crop_to_max_size(self, wav, target_size):
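"""Crop `wav` to `target_size` samples, choosing a random start offset
when random_crop is enabled (a leading crop otherwise); returns the crop
and the start offset so labels can be cropped consistently."""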
size = len(wav)
diff = size - target_size
if diff <= 0:
return wav, 0
start, end = 0, target_size
if self.random_crop:
start = np.random.randint(0, diff + 1)
end = size - diff + start
return wav[start:end], start
def collater(self, samples):
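"""Collate sample dicts into a training batch: pad or crop audio to a
common size, align label windows with the chosen audio crops, and
(optionally) build decoder targets plus shifted prev_output_tokens."""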
# target = max(sizes) -> random_crop not used
# target = max_sample_size -> random_crop used for long
samples = [s for s in samples if s["source"] is not None]
if len(samples) == 0:
return {}
audios = [s["source"] for s in samples]
audio_sizes = [len(s) for s in audios]
if self.pad_audio:
audio_size = min(max(audio_sizes), self.max_sample_size)
else:
audio_size = min(min(audio_sizes), self.max_sample_size)
feat_dim = audios[0].size(-1) if audios[0].dim() > 1 else 1
collated_audios, padding_mask, audio_starts = self.collater_audio(
audios, audio_size, feat_dim,
)
targets_by_label = [
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
]
targets_list, lengths_list, ntokens_list = self.collater_label(
targets_by_label, audio_size, audio_starts
)
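# Decoder targets: append EOS to each (possibly cropped) label sequence.
# In fine-tuning the labels are used as-is; in pre-training with a tokenizer,
# consecutive duplicate frame labels are optionally collapsed
# (unique_consecutive) and re-encoded with the sentencepiece model.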
if self.add_decoder_target:
if self.fine_tuning:
decoder_label = [
torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
for i in range(targets_list[0].size(0))
]
else:
if self.tokenizer is not None:
decoder_label = [
# Shift each label id by 48 before mapping to a char so the intermediate
# string stays printable (avoids control chars like \n) for sentencepiece
torch.cat(
(
torch.tensor(
self.tokenizer.sp.Encode(
"".join(
[chr(j + 48) for j in (
targets_list[0][i, :lengths_list[0][i]].unique_consecutive() if self.reduce_label_for_dec else targets_list[0][i, :lengths_list[0][i]]
).tolist()]
), out_type=int
)
),
torch.tensor([self.tgt_dict.eos()])
), dim=0
).long()
for i in range(targets_list[0].size(0))
]
else:
decoder_label = [
torch.cat((targets_list[0][i, :lengths_list[0][i]].unique_consecutive() if self.reduce_label_for_dec else targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
for i in range(targets_list[0].size(0))
]
if self.mbart_style_lang_id:
decoder_label = [
torch.cat((decoder_label[i], torch.tensor([self.tgt_lang_idx])), 0).long()
for i in range(targets_list[0].size(0))
]
dec_ntokens = sum(x.size(0) for x in decoder_label)
decoder_target = data_utils.collate_tokens(
decoder_label,
self.tgt_dict.pad(),
self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
left_pad=False,
move_eos_to_beginning=False,
)
decoder_target_lengths = torch.tensor(
[x.size(0) for x in decoder_label], dtype=torch.long
)
prev_output_tokens = data_utils.collate_tokens(
decoder_label,
self.tgt_dict.pad(),
self.tgt_dict.eos() if not self.mbart_style_lang_id else self.tgt_lang_idx,
left_pad=False,
move_eos_to_beginning=True,
)
if self.tgt_lang_idx is not None and not self.mbart_style_lang_id:
assert (prev_output_tokens[:, 0] != self.tgt_dict.eos()).sum() == 0
prev_output_tokens[:, 0] = self.tgt_lang_idx
net_input = {
"source": collated_audios,
"padding_mask": padding_mask,
"prev_output_tokens": prev_output_tokens,
}
batch = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": net_input,
"decoder_target": decoder_target,
"decoder_target_lengths": decoder_target_lengths,
"dec_ntokens": dec_ntokens,
"lang_idx": self.tgt_lang_idx,
}
else:
net_input = {"source": collated_audios, "padding_mask": padding_mask}
batch = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": net_input,
}
if self.single_target:
batch["target_lengths"] = lengths_list[0]
batch["ntokens"] = ntokens_list[0]
batch["target"] = targets_list[0]
else:
batch["target_lengths_list"] = lengths_list
batch["ntokens_list"] = ntokens_list
batch["target_list"] = targets_list
return batch
def collater_audio(self, audios, audio_size, feat_dim=1):
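"""Pad (with zeros) or crop every utterance to `audio_size`, tracking
padded positions in `padding_mask` and each utterance's crop start in
`audio_starts`; feat_dim > 1 handles pre-extracted feature inputs."""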
collated_audios = audios[0].new_zeros(len(audios), audio_size, feat_dim)
padding_mask = (
torch.BoolTensor(collated_audios.shape[0:2]).fill_(False)
# if self.pad_audio else None
)
audio_starts = [0 for _ in audios]
for i, audio in enumerate(audios):
audio = audio.view(-1, feat_dim)
diff = len(audio) - audio_size
if diff == 0:
collated_audios[i] = audio
elif diff < 0:
assert self.pad_audio
collated_audios[i] = torch.cat([audio, audio.new_full((-diff, feat_dim), 0.0)])
padding_mask[i, diff:] = True
else:
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
audio, audio_size
)
return collated_audios.squeeze(-1), padding_mask, audio_starts
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
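"""Crop frame-level labels so they cover the same time span as the
cropped audio: sample offsets are converted to frame offsets with
s2f = label_rate / sample_rate, then the windows are padded into a batch."""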
assert label_rate > 0
s2f = label_rate / self.sample_rate
frm_starts = [int(round(s * s2f)) for s in audio_starts]
frm_size = int(round(audio_size * s2f))
if not self.pad_audio:
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
frm_size = min(frm_size, *rem_size)
targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
logger.debug(f"audio_starts={audio_starts}")
logger.debug(f"frame_starts={frm_starts}")
logger.debug(f"frame_size={frm_size}")
lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item()
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths, ntokens
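# Worked example (hypothetical numbers): with 16 kHz audio and a 50 Hz label
# rate, s2f = 50 / 16000 = 0.003125, so an audio crop starting at sample
# 32,000 maps to frame 100, and a 250,000-sample window to ~781 frames.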
def collater_seq_label(self, targets, pad):
lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item()
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths, ntokens
def collater_label(self, targets_by_label, audio_size, audio_starts):
targets_list, lengths_list, ntokens_list = [], [], []
itr = zip(targets_by_label, self.label_rates, self.pad_list)
for targets, label_rate, pad in itr:
if label_rate == -1:
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
else:
targets, lengths, ntokens = self.collater_frm_label(
targets, audio_size, audio_starts, label_rate, pad
)
targets_list.append(targets)
lengths_list.append(lengths)
ntokens_list.append(ntokens)
return targets_list, lengths_list, ntokens_list
def num_tokens(self, index):
return self.size(index)
def size(self, index):
if self.pad_audio:
return self.wav_sizes[index]
return min(self.wav_sizes[index], self.max_sample_size)
@property
def sizes(self):
return np.array(self.wav_sizes)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
if len(self.chunk_names) > 0:
logger.info(f"ordered indices for epoch {self.epoch}")
with data_utils.numpy_seed(self.epoch):
self.chunk_order = np.random.permutation(len(self.chunk_names))
chunk_count = 0
tmp_sizes = []
tmp_indices = []
indice = []
for i in self.chunk_order:
chunk_count += 1
start = self.chunk_indices[i]
end = self.chunk_indices[i+1] if i < len(self.chunk_names) - 1 else len(self)
size = list(self.sizes[start:end])
tmp_indices.extend(list(np.arange(start, end)))
tmp_sizes.extend(size)
# flush a group every 10 chunks, and always flush the remainder on the
# last chunk so no indices are dropped from the epoch
if chunk_count % 10 == 0 or i == self.chunk_order[-1]:
order = [np.random.permutation(len(tmp_indices))]
order.append(
np.minimum(
np.array(tmp_sizes),
self.max_sample_size,
)
)
sort_idx = np.lexsort(order)[::-1]
indice.append(np.array([tmp_indices[k] for k in sort_idx]))
tmp_indices = []
tmp_sizes = []
return indice
else:
order = [np.random.permutation(len(self))]
order.append(
np.minimum(
np.array(self.sizes),
self.max_sample_size,
)
)
return np.lexsort(order)[::-1]
else:
return np.arange(len(self))
def postprocess(self, wav, cur_sample_rate):
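"""Mix multi-channel audio down to mono, verify the sample rate matches
the dataset's, and optionally layer-normalize the waveform."""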
if wav.dim() == 2:
wav = wav.mean(-1)
assert wav.dim() == 1, wav.dim()
if cur_sample_rate != self.sample_rate:
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
if self.normalize:
with torch.no_grad():
wav = F.layer_norm(wav, wav.shape)
return wav
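# Minimal usage sketch (illustrative only; the file names are hypothetical and
# `dictionary` is assumed to be a fairseq Dictionary; manifests/labels follow
# the formats handled by load_audio/load_label above):
#
#   dataset = HubertDataset(
#       manifest_path="train.tsv",
#       sample_rate=16000,
#       label_paths=["train.km"],
#       label_rates=50.0,
#       pad_list=[dictionary.pad()],
#       eos_list=[dictionary.eos()],
#       max_sample_size=250000,
#   )
#   samples = [dataset[i] for i in range(4)]
#   batch = dataset.collater(samples)  # {"id", "net_input", "target_list", ...}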