utils.cpp 15.9 KB
Newer Older
moto's avatar
moto committed
1
#include <c10/core/ScalarType.h>
2
3
#include <libtorchaudio/sox/types.h>
#include <libtorchaudio/sox/utils.h>
moto's avatar
moto committed
4
5
#include <sox.h>

Moto Hira's avatar
Moto Hira committed
6
namespace torchaudio::sox {
moto's avatar
moto committed
7

moto-meta's avatar
moto-meta committed
8
9
10
11
12
13
14
15
// Names of sox effects that list_effects() leaves out of its result.
const std::unordered_set<std::string> UNSUPPORTED_EFFECTS = {
    "input",
    "output",
    "spectrogram",
    "noiseprof",
    "noisered",
    "splice",
};

moto's avatar
moto committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// Stores `seed` (narrowed to sox_int32_t) in the `ranqd1` field of sox's
// global state, which sox uses as the seed for its random number generator.
void set_seed(const int64_t seed) {
  auto* const globals = sox_get_globals();
  globals->ranqd1 = static_cast<sox_int32_t>(seed);
}

// Sets the `verbosity` field of sox's global state (narrowed to unsigned).
void set_verbosity(const int64_t verbosity) {
  auto* const globals = sox_get_globals();
  globals->verbosity = static_cast<unsigned>(verbosity);
}

// Sets the `use_threads` flag of sox's global state from a C++ bool.
void set_use_threads(const bool use_threads) {
  auto* const globals = sox_get_globals();
  globals->use_threads = static_cast<sox_bool>(use_threads);
}

void set_buffer_size(const int64_t buffer_size) {
  sox_get_globals()->bufsiz = static_cast<size_t>(buffer_size);
}

32
33
34
35
// Returns the `bufsiz` field of sox's global state as a signed 64-bit value.
int64_t get_buffer_size() {
  const auto* const globals = sox_get_globals();
  return static_cast<int64_t>(globals->bufsiz);
}

moto's avatar
moto committed
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
std::vector<std::vector<std::string>> list_effects() {
  std::vector<std::vector<std::string>> effects;
  for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) {
    const sox_effect_handler_t* handler = (*fns)();
    if (handler && handler->name) {
      if (UNSUPPORTED_EFFECTS.find(handler->name) ==
          UNSUPPORTED_EFFECTS.end()) {
        effects.emplace_back(std::vector<std::string>{
            handler->name,
            handler->usage ? std::string(handler->usage) : std::string("")});
      }
    }
  }
  return effects;
}

52
std::vector<std::string> list_write_formats() {
moto's avatar
moto committed
53
54
  std::vector<std::string> formats;
  for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
55
56
    const sox_format_handler_t* handler = fns->fn();
    for (const char* const* names = handler->names; *names; ++names) {
moto-meta's avatar
moto-meta committed
57
      if (!strchr(*names, '/') && handler->write) {
58
        formats.emplace_back(*names);
moto-meta's avatar
moto-meta committed
59
      }
60
61
62
63
64
65
66
67
68
69
    }
  }
  return formats;
}

std::vector<std::string> list_read_formats() {
  std::vector<std::string> formats;
  for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
    const sox_format_handler_t* handler = fns->fn();
    for (const char* const* names = handler->names; *names; ++names) {
moto-meta's avatar
moto-meta committed
70
      if (!strchr(*names, '/') && handler->read) {
moto's avatar
moto committed
71
        formats.emplace_back(*names);
moto-meta's avatar
moto-meta committed
72
      }
moto's avatar
moto committed
73
74
75
76
77
    }
  }
  return formats;
}

moto's avatar
moto committed
78
SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {}
moto's avatar
moto committed
79
80
81
// Releases the wrapped handle (if still open) via close().
SoxFormat::~SoxFormat() {
  this->close();
}
82

moto's avatar
moto committed
83
84
85
// Member access on the wrapped sox handle (null if closed or never opened).
sox_format_t* SoxFormat::operator->() const noexcept {
  return this->fd_;
}
86
// Implicit access to the raw sox handle (null if closed or never opened).
SoxFormat::operator sox_format_t*() const noexcept {
  return this->fd_;
}

90
91
92
93
94
95
96
void SoxFormat::close() {
  if (fd_ != nullptr) {
    sox_close(fd_);
    fd_ = nullptr;
  }
}

97
void validate_input_file(const SoxFormat& sf, const std::string& path) {
98
99
100
101
102
103
  TORCH_CHECK(
      static_cast<sox_format_t*>(sf) != nullptr,
      "Error loading audio file: failed to open file " + path);
  TORCH_CHECK(
      sf->encoding.encoding != SOX_ENCODING_UNKNOWN,
      "Error loading audio file: unknown encoding.");
moto's avatar
moto committed
104
105
}

Moto Hira's avatar
Moto Hira committed
106
// Checks that `tensor` is a 2D CPU tensor of dtype uint8, int16, int32 or
// float32; throws (TORCH_CHECK) otherwise.
void validate_input_tensor(const torch::Tensor& tensor) {
  TORCH_CHECK(tensor.device().is_cpu(), "Input tensor has to be on CPU.");
  TORCH_CHECK(tensor.ndimension() == 2, "Input tensor has to be 2D.");

  const auto scalar_type = tensor.dtype().toScalarType();
  const bool is_supported_dtype = scalar_type == c10::ScalarType::Float ||
      scalar_type == c10::ScalarType::Int ||
      scalar_type == c10::ScalarType::Short ||
      scalar_type == c10::ScalarType::Byte;
  TORCH_CHECK(
      is_supported_dtype,
      "Input tensor has to be one of float32, int32, int16 or uint8 type.");
}

moto's avatar
moto committed
124
125
126
127
128
129
130
// Maps a sox encoding/precision pair to a torch dtype:
//   * SOX_ENCODING_UNSIGNED (e.g. 8-bit PCM WAV)        -> uint8
//   * SOX_ENCODING_SIGN2 at 16 bits                      -> int16
//   * SOX_ENCODING_SIGN2 at 24 or 32 bits                -> int32
//     (24-bit samples are widened to 32-bit)
//   * any other encoding (32-bit floating-point WAV, MP3,
//     FLAC, VORBIS, ...)                                 -> float32
// Throws (TORCH_CHECK) for signed PCM at any other precision.
caffe2::TypeMeta get_dtype(
    const sox_encoding_t encoding,
    const unsigned precision) {
  auto scalar_type = torch::kFloat32;
  if (encoding == SOX_ENCODING_UNSIGNED) {
    scalar_type = torch::kUInt8;
  } else if (encoding == SOX_ENCODING_SIGN2) {
    if (precision == 16) {
      scalar_type = torch::kInt16;
    } else if (precision == 24 || precision == 32) {
      scalar_type = torch::kInt32;
    } else {
      TORCH_CHECK(
          false, "Only 16, 24, and 32 bits are supported for signed PCM.");
    }
  }
  return c10::scalarTypeToTypeMeta(scalar_type);
}
154

moto's avatar
moto committed
155
156
157
158
159
160
161
// Copies interleaved sox samples from `buffer` into a newly allocated tensor.
//
// The first num_frames * num_channels samples, where
// num_frames = num_samples / num_channels, are laid out as
// (num_frames, num_channels).
// - If `normalize` is set or `dtype` is float32, each sample is converted to
//   float32 with SOX_SAMPLE_TO_FLOAT_32BIT.
// - Otherwise the sample is converted to the requested integer dtype; int32 is
//   copied verbatim.
// - With `channels_first` the result is transposed to (num_channels, num_frames).
// The returned tensor is contiguous and never aliases `buffer`.
//
// Any trailing partial frame (num_samples not a multiple of num_channels) is
// dropped. Previously the non-int32 paths wrote all num_samples values past
// the end of the smaller allocation.
//
// Throws (TORCH_CHECK) if num_channels is not positive or dtype is unsupported.
torch::Tensor convert_to_tensor(
    sox_sample_t* buffer,
    const int32_t num_samples,
    const int32_t num_channels,
    const caffe2::TypeMeta dtype,
    const bool normalize,
    const bool channels_first) {
  TORCH_CHECK(
      num_channels > 0,
      "convert_to_tensor: num_channels must be positive, but got ",
      num_channels);
  const int64_t num_frames = num_samples / num_channels;
  // Number of elements actually allocated below; loops must not exceed it.
  const int64_t num_elements = num_frames * num_channels;

  torch::Tensor t;
  // The SOX_SAMPLE_TO_* macros take a clip counter; the count is not used here.
  uint64_t clip_count = 0;
  // Scratch locals required by the SOX_SAMPLE_TO_* macros (declared in sox.h).
  SOX_SAMPLE_LOCALS;
  if (normalize || dtype == torch::kFloat32) {
    t = torch::empty({num_frames, num_channels}, torch::kFloat32);
    // Use `float` explicitly: float_t is implementation-defined and may not
    // match the tensor's 32-bit element type.
    auto ptr = t.data_ptr<float>();
    for (int64_t i = 0; i < num_elements; ++i) {
      ptr[i] = SOX_SAMPLE_TO_FLOAT_32BIT(buffer[i], clip_count);
    }
  } else if (dtype == torch::kInt32) {
    // sox samples are already 32-bit signed; copy so the result owns its data.
    t = torch::from_blob(buffer, {num_frames, num_channels}, torch::kInt32)
            .clone();
  } else if (dtype == torch::kInt16) {
    t = torch::empty({num_frames, num_channels}, torch::kInt16);
    auto ptr = t.data_ptr<int16_t>();
    for (int64_t i = 0; i < num_elements; ++i) {
      ptr[i] = SOX_SAMPLE_TO_SIGNED_16BIT(buffer[i], clip_count);
    }
  } else if (dtype == torch::kUInt8) {
    t = torch::empty({num_frames, num_channels}, torch::kUInt8);
    auto ptr = t.data_ptr<uint8_t>();
    for (int64_t i = 0; i < num_elements; ++i) {
      ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], clip_count);
    }
  } else {
    TORCH_CHECK(false, "Unsupported dtype: ", dtype);
  }
  if (channels_first) {
    t = t.transpose(1, 0);
  }
  return t.contiguous();
}

Moto Hira's avatar
Moto Hira committed
197
// Returns the lower-cased text after the last '.' in `path`.
// If `path` contains no '.', find_last_of returns npos and npos + 1 wraps to
// 0, so the whole (lower-cased) path is returned.
const std::string get_filetype(const std::string& path) {
  std::string ext = path.substr(path.find_last_of('.') + 1);
  // ::tolower requires a value representable as unsigned char (or EOF);
  // passing a negative plain char (non-ASCII byte) is undefined behavior.
  std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) {
    return static_cast<char>(::tolower(c));
  });
  return ext;
}

203
204
205
namespace {

std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
moto-meta's avatar
moto-meta committed
206
    const std::string& format,
207
    caffe2::TypeMeta dtype,
208
209
210
211
212
213
    const Encoding& encoding,
    const BitDepth& bits_per_sample) {
  switch (encoding) {
    case Encoding::NOT_PROVIDED:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
214
215
216
217
218
219
220
221
222
223
          switch (dtype.toScalarType()) {
            case c10::ScalarType::Float:
              return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
            case c10::ScalarType::Int:
              return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
            case c10::ScalarType::Short:
              return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
            case c10::ScalarType::Byte:
              return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
            default:
224
              TORCH_CHECK(false, "Internal Error: Unexpected dtype: ", dtype);
225
          }
226
227
228
229
230
231
232
233
234
        case BitDepth::B8:
          return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
        default:
          return std::make_tuple<>(
              SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
      }
    case Encoding::PCM_SIGNED:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
moto's avatar
moto committed
235
          return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
236
        case BitDepth::B8:
237
238
          TORCH_CHECK(
              false, format, " does not support 8-bit signed PCM encoding.");
239
240
241
242
243
244
245
246
247
248
        default:
          return std::make_tuple<>(
              SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
      }
    case Encoding::PCM_UNSIGNED:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
        case BitDepth::B8:
          return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
        default:
249
250
          TORCH_CHECK(
              false, format, " only supports 8-bit for unsigned PCM encoding.");
251
252
253
254
255
256
257
258
259
      }
    case Encoding::PCM_FLOAT:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
        case BitDepth::B32:
          return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
        case BitDepth::B64:
          return std::make_tuple<>(SOX_ENCODING_FLOAT, 64);
        default:
260
261
262
          TORCH_CHECK(
              false,
              format,
263
264
265
266
267
268
269
270
              " only supports 32-bit or 64-bit for floating-point PCM encoding.");
      }
    case Encoding::ULAW:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
        case BitDepth::B8:
          return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
        default:
271
272
          TORCH_CHECK(
              false, format, " only supports 8-bit for mu-law encoding.");
273
274
275
276
277
278
279
      }
    case Encoding::ALAW:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
        case BitDepth::B8:
          return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
        default:
280
281
          TORCH_CHECK(
              false, format, " only supports 8-bit for a-law encoding.");
282
283
      }
    default:
284
285
      TORCH_CHECK(
          false, format, " does not support encoding: " + to_string(encoding));
286
287
288
289
290
  }
}

std::tuple<sox_encoding_t, unsigned> get_save_encoding(
    const std::string& format,
Moto Hira's avatar
Moto Hira committed
291
292
293
    const caffe2::TypeMeta& dtype,
    const c10::optional<std::string>& encoding,
    const c10::optional<int64_t>& bits_per_sample) {
294
295
296
297
298
299
300
301
302
  const Format fmt = get_format_from_string(format);
  const Encoding enc = get_encoding_from_option(encoding);
  const BitDepth bps = get_bit_depth_from_option(bits_per_sample);

  switch (fmt) {
    case Format::WAV:
    case Format::AMB:
      return get_save_encoding_for_wav(format, dtype, enc, bps);
    case Format::MP3:
303
304
305
306
307
308
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "mp3 does not support `encoding` option.");
      TORCH_CHECK(
          bps == BitDepth::NOT_PROVIDED,
          "mp3 does not support `bits_per_sample` option.");
309
      return std::make_tuple<>(SOX_ENCODING_MP3, 16);
310
    case Format::HTK:
311
312
313
314
315
316
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "htk does not support `encoding` option.");
      TORCH_CHECK(
          bps == BitDepth::NOT_PROVIDED,
          "htk does not support `bits_per_sample` option.");
317
      return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
318
    case Format::VORBIS:
319
320
321
322
323
324
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "vorbis does not support `encoding` option.");
      TORCH_CHECK(
          bps == BitDepth::NOT_PROVIDED,
          "vorbis does not support `bits_per_sample` option.");
moto's avatar
moto committed
325
      return std::make_tuple<>(SOX_ENCODING_VORBIS, 0);
326
    case Format::AMR_NB:
327
328
329
330
331
332
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "amr-nb does not support `encoding` option.");
      TORCH_CHECK(
          bps == BitDepth::NOT_PROVIDED,
          "amr-nb does not support `bits_per_sample` option.");
333
334
      return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16);
    case Format::FLAC:
335
336
337
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "flac does not support `encoding` option.");
338
339
340
      switch (bps) {
        case BitDepth::B32:
        case BitDepth::B64:
341
342
          TORCH_CHECK(
              false, "flac does not support `bits_per_sample` larger than 24.");
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
        default:
          return std::make_tuple<>(
              SOX_ENCODING_FLAC, static_cast<unsigned>(bps));
      }
    case Format::SPHERE:
      switch (enc) {
        case Encoding::NOT_PROVIDED:
        case Encoding::PCM_SIGNED:
          switch (bps) {
            case BitDepth::NOT_PROVIDED:
              return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
            default:
              return std::make_tuple<>(
                  SOX_ENCODING_SIGN2, static_cast<unsigned>(bps));
          }
        case Encoding::PCM_UNSIGNED:
359
          TORCH_CHECK(false, "sph does not support unsigned integer PCM.");
360
        case Encoding::PCM_FLOAT:
361
          TORCH_CHECK(false, "sph does not support floating point PCM.");
362
363
364
365
366
367
        case Encoding::ULAW:
          switch (bps) {
            case BitDepth::NOT_PROVIDED:
            case BitDepth::B8:
              return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
            default:
368
369
              TORCH_CHECK(
                  false, "sph only supports 8-bit for mu-law encoding.");
370
371
372
373
374
375
376
377
378
379
380
          }
        case Encoding::ALAW:
          switch (bps) {
            case BitDepth::NOT_PROVIDED:
            case BitDepth::B8:
              return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
            default:
              return std::make_tuple<>(
                  SOX_ENCODING_ALAW, static_cast<unsigned>(bps));
          }
        default:
381
382
          TORCH_CHECK(
              false, "sph does not support encoding: ", encoding.value());
383
      }
384
    case Format::GSM:
385
386
387
388
389
390
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "gsm does not support `encoding` option.");
      TORCH_CHECK(
          bps == BitDepth::NOT_PROVIDED,
          "gsm does not support `bits_per_sample` option.");
391
392
      return std::make_tuple<>(SOX_ENCODING_GSM, 16);

393
    default:
394
      TORCH_CHECK(false, "Unsupported format: " + format);
395
396
397
  }
}

Moto Hira's avatar
Moto Hira committed
398
unsigned get_precision(const std::string& filetype, caffe2::TypeMeta dtype) {
moto-meta's avatar
moto-meta committed
399
  if (filetype == "mp3") {
400
    return SOX_UNSPEC;
moto-meta's avatar
moto-meta committed
401
402
  }
  if (filetype == "flac") {
403
    return 24;
moto-meta's avatar
moto-meta committed
404
405
  }
  if (filetype == "ogg" || filetype == "vorbis") {
406
    return SOX_UNSPEC;
moto-meta's avatar
moto-meta committed
407
  }
408
  if (filetype == "wav" || filetype == "amb") {
409
410
411
412
413
414
415
416
417
418
    switch (dtype.toScalarType()) {
      case c10::ScalarType::Byte:
        return 8;
      case c10::ScalarType::Short:
        return 16;
      case c10::ScalarType::Int:
        return 32;
      case c10::ScalarType::Float:
        return 32;
      default:
419
        TORCH_CHECK(false, "Unsupported dtype: ", dtype);
420
    }
421
  }
moto-meta's avatar
moto-meta committed
422
  if (filetype == "sph") {
moto's avatar
moto committed
423
    return 32;
moto-meta's avatar
moto-meta committed
424
  }
425
426
427
  if (filetype == "amr-nb") {
    return 16;
  }
428
  if (filetype == "gsm") {
429
    return 16;
430
431
432
433
  }
  if (filetype == "htk") {
    return 16;
  }
434
  TORCH_CHECK(false, "Unsupported file type: ", filetype);
435
436
}

437
438
} // namespace

439
// Builds the sox_signalinfo_t describing `waveform` for writing as `filetype`.
// - The channel count is read from dim 0 when `channels_first`, else dim 1.
// - The precision comes from get_precision().
// - The length is waveform->numel(), i.e. the sample count across all channels.
// - `mult` is null.
sox_signalinfo_t get_signalinfo(
    const torch::Tensor* waveform,
    const int64_t sample_rate,
    const std::string& filetype,
    const bool channels_first) {
  const int64_t channel_dim = channels_first ? 0 : 1;
  const auto num_channels =
      static_cast<unsigned>(waveform->size(channel_dim));
  const unsigned precision = get_precision(filetype, waveform->dtype());
  const auto total_samples = static_cast<uint64_t>(waveform->numel());
  return sox_signalinfo_t{
      /*rate=*/static_cast<sox_rate_t>(sample_rate),
      /*channels=*/num_channels,
      /*precision=*/precision,
      /*length=*/total_samples,
      /*mult=*/nullptr};
}

453
// Describes the in-memory sample format of a tensor of `dtype` for sox:
//   uint8 -> unsigned 8-bit, int16 -> signed 16-bit,
//   int32 -> signed 32-bit,  float32 -> 32-bit float.
// Compression is HUGE_VAL, and the byte/nibble/bit reversal options are
// sox_option_default. Throws (TORCH_CHECK) for any other dtype.
sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
  sox_encoding_t encoding{};
  unsigned bits_per_sample = 0;
  switch (dtype.toScalarType()) {
    case c10::ScalarType::Byte:
      encoding = SOX_ENCODING_UNSIGNED;
      bits_per_sample = 8;
      break;
    case c10::ScalarType::Short:
      encoding = SOX_ENCODING_SIGN2;
      bits_per_sample = 16;
      break;
    case c10::ScalarType::Int:
      encoding = SOX_ENCODING_SIGN2;
      bits_per_sample = 32;
      break;
    case c10::ScalarType::Float:
      encoding = SOX_ENCODING_FLOAT;
      bits_per_sample = 32;
      break;
    default:
      TORCH_CHECK(false, "Unsupported dtype: ", dtype);
  }
  return sox_encodinginfo_t{
      /*encoding=*/encoding,
      /*bits_per_sample=*/bits_per_sample,
      /*compression=*/HUGE_VAL,
      /*reverse_bytes=*/sox_option_default,
      /*reverse_nibbles=*/sox_option_default,
      /*reverse_bits=*/sox_option_default,
      /*opposite_endian=*/sox_false};
}
491

492
// Builds the sox_encodinginfo_t used to write `format`.
// - The encoding and bit depth come from get_save_encoding(), which may throw.
// - `compression` falls back to HUGE_VAL when not given.
// - The byte/nibble/bit reversal options are sox_option_default.
sox_encodinginfo_t get_encodinginfo_for_save(
    const std::string& format,
    const caffe2::TypeMeta& dtype,
    const c10::optional<double>& compression,
    const c10::optional<std::string>& encoding,
    const c10::optional<int64_t>& bits_per_sample) {
  const auto [sox_encoding, sox_bits] =
      get_save_encoding(format, dtype, encoding, bits_per_sample);
  const double sox_compression = compression.value_or(HUGE_VAL);
  return sox_encodinginfo_t{
      /*encoding=*/sox_encoding,
      /*bits_per_sample=*/sox_bits,
      /*compression=*/sox_compression,
      /*reverse_bytes=*/sox_option_default,
      /*reverse_nibbles=*/sox_option_default,
      /*reverse_bits=*/sox_option_default,
      /*opposite_endian=*/sox_false};
}

Moto Hira's avatar
Moto Hira committed
509
} // namespace torchaudio::sox