utils.cpp 15.7 KB
Newer Older
moto's avatar
moto committed
1
#include <c10/core/ScalarType.h>
2
3
#include <libtorchaudio/sox/types.h>
#include <libtorchaudio/sox/utils.h>
moto's avatar
moto committed
4
5
#include <sox.h>

Moto Hira's avatar
Moto Hira committed
6
namespace torchaudio::sox {
moto's avatar
moto committed
7

moto's avatar
moto committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// Stores `seed`, narrowed to sox_int32_t, in libsox's global `ranqd1` field
// (libsox's random-number seed).
void set_seed(const int64_t seed) {
  auto* const globals = sox_get_globals();
  globals->ranqd1 = static_cast<sox_int32_t>(seed);
}

// Sets libsox's global verbosity level; `verbosity` is converted to unsigned.
void set_verbosity(const int64_t verbosity) {
  const auto level = static_cast<unsigned>(verbosity);
  sox_get_globals()->verbosity = level;
}

void set_use_threads(const bool use_threads) {
  sox_get_globals()->use_threads = static_cast<sox_bool>(use_threads);
}

// Sets libsox's global buffer size (`bufsiz`).
// NOTE(review): a negative `buffer_size` wraps to a huge value when converted
// to size_t; callers are presumably expected to pass a positive size.
void set_buffer_size(const int64_t buffer_size) {
  const auto new_size = static_cast<size_t>(buffer_size);
  sox_get_globals()->bufsiz = new_size;
}

24
25
26
27
// Returns libsox's current global buffer size (`bufsiz`) as int64_t.
int64_t get_buffer_size() {
  const auto* const globals = sox_get_globals();
  return static_cast<int64_t>(globals->bufsiz);
}

moto's avatar
moto committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
std::vector<std::vector<std::string>> list_effects() {
  std::vector<std::vector<std::string>> effects;
  for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) {
    const sox_effect_handler_t* handler = (*fns)();
    if (handler && handler->name) {
      if (UNSUPPORTED_EFFECTS.find(handler->name) ==
          UNSUPPORTED_EFFECTS.end()) {
        effects.emplace_back(std::vector<std::string>{
            handler->name,
            handler->usage ? std::string(handler->usage) : std::string("")});
      }
    }
  }
  return effects;
}

44
std::vector<std::string> list_write_formats() {
moto's avatar
moto committed
45
46
  std::vector<std::string> formats;
  for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
    const sox_format_handler_t* handler = fns->fn();
    for (const char* const* names = handler->names; *names; ++names) {
      if (!strchr(*names, '/') && handler->write)
        formats.emplace_back(*names);
    }
  }
  return formats;
}

std::vector<std::string> list_read_formats() {
  std::vector<std::string> formats;
  for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
    const sox_format_handler_t* handler = fns->fn();
    for (const char* const* names = handler->names; *names; ++names) {
      if (!strchr(*names, '/') && handler->read)
moto's avatar
moto committed
62
63
64
65
66
67
        formats.emplace_back(*names);
    }
  }
  return formats;
}

moto's avatar
moto committed
68
SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {}
moto's avatar
moto committed
69
70
71
// Releases the wrapped handle, if any, via close().
SoxFormat::~SoxFormat() {
  this->close();
}
72

moto's avatar
moto committed
73
74
75
// Member access on the wrapped handle. The handle is nullptr after close()
// or when constructed from nullptr, so check before dereferencing.
sox_format_t* SoxFormat::operator->() const noexcept {
  return this->fd_;
}
76
// Yields the raw wrapped handle (possibly nullptr) without giving up ownership.
SoxFormat::operator sox_format_t*() const noexcept {
  return this->fd_;
}

80
81
82
83
84
85
86
// Closes the wrapped handle with sox_close() and clears it, so calling
// close() again (e.g. from the destructor) is a no-op.
void SoxFormat::close() {
  if (fd_ == nullptr) {
    return; // nothing to release
  }
  sox_close(fd_);
  fd_ = nullptr;
}

87
void validate_input_file(const SoxFormat& sf, const std::string& path) {
88
89
90
91
92
93
  TORCH_CHECK(
      static_cast<sox_format_t*>(sf) != nullptr,
      "Error loading audio file: failed to open file " + path);
  TORCH_CHECK(
      sf->encoding.encoding != SOX_ENCODING_UNKNOWN,
      "Error loading audio file: unknown encoding.");
moto's avatar
moto committed
94
95
}

Moto Hira's avatar
Moto Hira committed
96
// Checks that `tensor` is a 2-D CPU tensor whose dtype is uint8, int16, int32
// or float32. Throws (TORCH_CHECK) on the first violated requirement.
void validate_input_tensor(const torch::Tensor& tensor) {
  TORCH_CHECK(tensor.device().is_cpu(), "Input tensor has to be on CPU.");
  TORCH_CHECK(tensor.ndimension() == 2, "Input tensor has to be 2D.");

  const c10::ScalarType st = tensor.dtype().toScalarType();
  const bool is_supported_dtype = st == c10::ScalarType::Byte ||
      st == c10::ScalarType::Short || st == c10::ScalarType::Int ||
      st == c10::ScalarType::Float;
  TORCH_CHECK(
      is_supported_dtype,
      "Input tensor has to be one of float32, int32, int16 or uint8 type.");
}

moto's avatar
moto committed
114
115
116
117
118
119
120
// Chooses the tensor dtype used to hold samples with the given sox `encoding`
// and `precision` (bits per sample):
//
//   SOX_ENCODING_UNSIGNED           -> uint8   (e.g. 8-bit PCM WAV)
//   SOX_ENCODING_SIGN2, 16 bits     -> int16
//   SOX_ENCODING_SIGN2, 24/32 bits  -> int32   (24-bit is widened to 32-bit)
//   any other encoding              -> float32 (e.g. 32-bit floating-point
//                                               WAV, MP3, FLAC, VORBIS)
//
// Throws (TORCH_CHECK) for signed PCM with any other precision.
caffe2::TypeMeta get_dtype(
    const sox_encoding_t encoding,
    const unsigned precision) {
  c10::ScalarType scalar_type = torch::kFloat32;
  if (encoding == SOX_ENCODING_UNSIGNED) {
    scalar_type = torch::kUInt8;
  } else if (encoding == SOX_ENCODING_SIGN2) {
    if (precision == 16) {
      scalar_type = torch::kInt16;
    } else if (precision == 24 || precision == 32) {
      scalar_type = torch::kInt32;
    } else {
      TORCH_CHECK(
          false, "Only 16, 24, and 32 bits are supported for signed PCM.");
    }
  }
  return c10::scalarTypeToTypeMeta(scalar_type);
}
144

moto's avatar
moto committed
145
146
147
148
149
150
151
// Builds a tensor from `num_samples` interleaved sox samples in `buffer`.
//
// Shape and layout:
//   - Without `channels_first`, the result is [num_samples / num_channels,
//     num_channels].
//   - With `channels_first`, it is the transpose, [num_channels,
//     num_samples / num_channels].
//   - The result is always contiguous.
//
// Sample conversion:
//   - If `normalize` is true or `dtype` is float32, samples are scaled to
//     float32 with SOX_SAMPLE_TO_FLOAT_32BIT.
//   - Otherwise int32 samples are copied verbatim, and int16 / uint8 samples
//     are converted with SOX_SAMPLE_TO_SIGNED_16BIT /
//     SOX_SAMPLE_TO_UNSIGNED_8BIT.
//
// Only whole frames are copied. If `num_samples` is not a multiple of
// `num_channels`, the trailing partial frame is dropped, which matches what
// the int32 from_blob path does. Bounding the loops this way keeps writes
// inside the output tensor.
//
// Throws (TORCH_CHECK) if `num_channels` is not positive or if `dtype` is not
// float32, int32, int16 or uint8.
torch::Tensor convert_to_tensor(
    sox_sample_t* buffer,
    const int32_t num_samples,
    const int32_t num_channels,
    const caffe2::TypeMeta dtype,
    const bool normalize,
    const bool channels_first) {
  TORCH_CHECK(
      num_channels > 0,
      "Expected num_channels to be positive, but got ",
      num_channels);
  const int32_t num_frames = num_samples / num_channels;
  // Number of samples that fit in the [num_frames, num_channels] output.
  const int64_t num_values = static_cast<int64_t>(num_frames) * num_channels;

  torch::Tensor t;
  // Passed as the clip-count argument of the SOX_SAMPLE_TO_* macros; the
  // value is never read afterwards.
  uint64_t dummy = 0;
  // sox.h macro: declares the scratch locals the SOX_SAMPLE_TO_* macros use.
  SOX_SAMPLE_LOCALS;
  if (normalize || dtype == torch::kFloat32) {
    t = torch::empty({num_frames, num_channels}, torch::kFloat32);
    // Use `float`, not `float_t`. float_t may be wider than float on some
    // platforms, and data_ptr<T>() must match the float32 dtype.
    auto ptr = t.data_ptr<float>();
    for (int64_t i = 0; i < num_values; ++i) {
      ptr[i] = SOX_SAMPLE_TO_FLOAT_32BIT(buffer[i], dummy);
    }
  } else if (dtype == torch::kInt32) {
    // sox_sample_t is already 32-bit signed, so copy the raw values.
    t = torch::from_blob(buffer, {num_frames, num_channels}, torch::kInt32)
            .clone();
  } else if (dtype == torch::kInt16) {
    t = torch::empty({num_frames, num_channels}, torch::kInt16);
    auto ptr = t.data_ptr<int16_t>();
    for (int64_t i = 0; i < num_values; ++i) {
      ptr[i] = SOX_SAMPLE_TO_SIGNED_16BIT(buffer[i], dummy);
    }
  } else if (dtype == torch::kUInt8) {
    t = torch::empty({num_frames, num_channels}, torch::kUInt8);
    auto ptr = t.data_ptr<uint8_t>();
    for (int64_t i = 0; i < num_values; ++i) {
      ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy);
    }
  } else {
    TORCH_CHECK(false, "Unsupported dtype: ", dtype);
  }
  if (channels_first) {
    t = t.transpose(1, 0);
  }
  return t.contiguous();
}

Moto Hira's avatar
Moto Hira committed
187
// Returns the lower-cased text after the last '.' in `path`. If `path` has no
// '.', the whole path is returned lower-cased.
// NOTE(review): the whole path is searched, so a '.' in a directory name
// (e.g. "dir.d/file") yields "d/file".
const std::string get_filetype(const std::string& path) {
  std::string ext = path.substr(path.find_last_of('.') + 1);
  // tolower() is undefined for negative char values (e.g. non-ASCII bytes
  // when char is signed), so route each byte through unsigned char.
  std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) {
    return static_cast<char>(::tolower(c));
  });
  return ext;
}

193
194
195
196
namespace {

std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
    const std::string format,
197
    caffe2::TypeMeta dtype,
198
199
200
201
202
203
    const Encoding& encoding,
    const BitDepth& bits_per_sample) {
  switch (encoding) {
    case Encoding::NOT_PROVIDED:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
204
205
206
207
208
209
210
211
212
213
          switch (dtype.toScalarType()) {
            case c10::ScalarType::Float:
              return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
            case c10::ScalarType::Int:
              return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
            case c10::ScalarType::Short:
              return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
            case c10::ScalarType::Byte:
              return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
            default:
214
              TORCH_CHECK(false, "Internal Error: Unexpected dtype: ", dtype);
215
          }
216
217
218
219
220
221
222
223
224
        case BitDepth::B8:
          return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
        default:
          return std::make_tuple<>(
              SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
      }
    case Encoding::PCM_SIGNED:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
moto's avatar
moto committed
225
          return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
226
        case BitDepth::B8:
227
228
          TORCH_CHECK(
              false, format, " does not support 8-bit signed PCM encoding.");
229
230
231
232
233
234
235
236
237
238
        default:
          return std::make_tuple<>(
              SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
      }
    case Encoding::PCM_UNSIGNED:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
        case BitDepth::B8:
          return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
        default:
239
240
          TORCH_CHECK(
              false, format, " only supports 8-bit for unsigned PCM encoding.");
241
242
243
244
245
246
247
248
249
      }
    case Encoding::PCM_FLOAT:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
        case BitDepth::B32:
          return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
        case BitDepth::B64:
          return std::make_tuple<>(SOX_ENCODING_FLOAT, 64);
        default:
250
251
252
          TORCH_CHECK(
              false,
              format,
253
254
255
256
257
258
259
260
              " only supports 32-bit or 64-bit for floating-point PCM encoding.");
      }
    case Encoding::ULAW:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
        case BitDepth::B8:
          return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
        default:
261
262
          TORCH_CHECK(
              false, format, " only supports 8-bit for mu-law encoding.");
263
264
265
266
267
268
269
      }
    case Encoding::ALAW:
      switch (bits_per_sample) {
        case BitDepth::NOT_PROVIDED:
        case BitDepth::B8:
          return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
        default:
270
271
          TORCH_CHECK(
              false, format, " only supports 8-bit for a-law encoding.");
272
273
      }
    default:
274
275
      TORCH_CHECK(
          false, format, " does not support encoding: " + to_string(encoding));
276
277
278
279
280
  }
}

std::tuple<sox_encoding_t, unsigned> get_save_encoding(
    const std::string& format,
Moto Hira's avatar
Moto Hira committed
281
282
283
    const caffe2::TypeMeta& dtype,
    const c10::optional<std::string>& encoding,
    const c10::optional<int64_t>& bits_per_sample) {
284
285
286
287
288
289
290
291
292
  const Format fmt = get_format_from_string(format);
  const Encoding enc = get_encoding_from_option(encoding);
  const BitDepth bps = get_bit_depth_from_option(bits_per_sample);

  switch (fmt) {
    case Format::WAV:
    case Format::AMB:
      return get_save_encoding_for_wav(format, dtype, enc, bps);
    case Format::MP3:
293
294
295
296
297
298
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "mp3 does not support `encoding` option.");
      TORCH_CHECK(
          bps == BitDepth::NOT_PROVIDED,
          "mp3 does not support `bits_per_sample` option.");
299
      return std::make_tuple<>(SOX_ENCODING_MP3, 16);
300
    case Format::HTK:
301
302
303
304
305
306
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "htk does not support `encoding` option.");
      TORCH_CHECK(
          bps == BitDepth::NOT_PROVIDED,
          "htk does not support `bits_per_sample` option.");
307
      return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
308
    case Format::VORBIS:
309
310
311
312
313
314
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "vorbis does not support `encoding` option.");
      TORCH_CHECK(
          bps == BitDepth::NOT_PROVIDED,
          "vorbis does not support `bits_per_sample` option.");
moto's avatar
moto committed
315
      return std::make_tuple<>(SOX_ENCODING_VORBIS, 0);
316
    case Format::AMR_NB:
317
318
319
320
321
322
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "amr-nb does not support `encoding` option.");
      TORCH_CHECK(
          bps == BitDepth::NOT_PROVIDED,
          "amr-nb does not support `bits_per_sample` option.");
323
324
      return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16);
    case Format::FLAC:
325
326
327
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "flac does not support `encoding` option.");
328
329
330
      switch (bps) {
        case BitDepth::B32:
        case BitDepth::B64:
331
332
          TORCH_CHECK(
              false, "flac does not support `bits_per_sample` larger than 24.");
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
        default:
          return std::make_tuple<>(
              SOX_ENCODING_FLAC, static_cast<unsigned>(bps));
      }
    case Format::SPHERE:
      switch (enc) {
        case Encoding::NOT_PROVIDED:
        case Encoding::PCM_SIGNED:
          switch (bps) {
            case BitDepth::NOT_PROVIDED:
              return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
            default:
              return std::make_tuple<>(
                  SOX_ENCODING_SIGN2, static_cast<unsigned>(bps));
          }
        case Encoding::PCM_UNSIGNED:
349
          TORCH_CHECK(false, "sph does not support unsigned integer PCM.");
350
        case Encoding::PCM_FLOAT:
351
          TORCH_CHECK(false, "sph does not support floating point PCM.");
352
353
354
355
356
357
        case Encoding::ULAW:
          switch (bps) {
            case BitDepth::NOT_PROVIDED:
            case BitDepth::B8:
              return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
            default:
358
359
              TORCH_CHECK(
                  false, "sph only supports 8-bit for mu-law encoding.");
360
361
362
363
364
365
366
367
368
369
370
          }
        case Encoding::ALAW:
          switch (bps) {
            case BitDepth::NOT_PROVIDED:
            case BitDepth::B8:
              return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
            default:
              return std::make_tuple<>(
                  SOX_ENCODING_ALAW, static_cast<unsigned>(bps));
          }
        default:
371
372
          TORCH_CHECK(
              false, "sph does not support encoding: ", encoding.value());
373
      }
374
    case Format::GSM:
375
376
377
378
379
380
      TORCH_CHECK(
          enc == Encoding::NOT_PROVIDED,
          "gsm does not support `encoding` option.");
      TORCH_CHECK(
          bps == BitDepth::NOT_PROVIDED,
          "gsm does not support `bits_per_sample` option.");
381
382
      return std::make_tuple<>(SOX_ENCODING_GSM, 16);

383
    default:
384
      TORCH_CHECK(false, "Unsupported format: " + format);
385
386
387
  }
}

Moto Hira's avatar
Moto Hira committed
388
unsigned get_precision(const std::string& filetype, caffe2::TypeMeta dtype) {
389
390
391
392
393
394
  if (filetype == "mp3")
    return SOX_UNSPEC;
  if (filetype == "flac")
    return 24;
  if (filetype == "ogg" || filetype == "vorbis")
    return SOX_UNSPEC;
395
  if (filetype == "wav" || filetype == "amb") {
396
397
398
399
400
401
402
403
404
405
    switch (dtype.toScalarType()) {
      case c10::ScalarType::Byte:
        return 8;
      case c10::ScalarType::Short:
        return 16;
      case c10::ScalarType::Int:
        return 32;
      case c10::ScalarType::Float:
        return 32;
      default:
406
        TORCH_CHECK(false, "Unsupported dtype: ", dtype);
407
    }
408
  }
moto's avatar
moto committed
409
410
  if (filetype == "sph")
    return 32;
411
412
413
  if (filetype == "amr-nb") {
    return 16;
  }
414
  if (filetype == "gsm") {
415
    return 16;
416
417
418
419
  }
  if (filetype == "htk") {
    return 16;
  }
420
  TORCH_CHECK(false, "Unsupported file type: ", filetype);
421
422
}

423
424
} // namespace

425
// Describes `waveform` as a sox_signalinfo_t, with these fields:
//   - rate: `sample_rate`.
//   - channels: the size of dim 0 when `channels_first`, otherwise dim 1.
//   - precision: get_precision(filetype, dtype).
//   - length: the total element count of `waveform`.
// Fields not set here (e.g. `mult`) are value-initialized.
sox_signalinfo_t get_signalinfo(
    const torch::Tensor* waveform,
    const int64_t sample_rate,
    const std::string& filetype,
    const bool channels_first) {
  const int64_t channel_dim = channels_first ? 0 : 1;

  sox_signalinfo_t info{};
  info.rate = static_cast<sox_rate_t>(sample_rate);
  info.channels = static_cast<unsigned>(waveform->size(channel_dim));
  info.precision = get_precision(filetype, waveform->dtype());
  info.length = static_cast<uint64_t>(waveform->numel());
  return info;
}

438
// Builds the sox_encodinginfo_t that matches the in-memory layout of `dtype`:
//   uint8   -> SOX_ENCODING_UNSIGNED,  8 bits
//   int16   -> SOX_ENCODING_SIGN2,    16 bits
//   int32   -> SOX_ENCODING_SIGN2,    32 bits
//   float32 -> SOX_ENCODING_FLOAT,    32 bits
// The remaining fields are set as follows:
//   - compression is HUGE_VAL (presumably sox's "unspecified" value —
//     confirm);
//   - the reverse_* options are sox_option_default;
//   - opposite_endian is sox_false.
// Throws (TORCH_CHECK) for any other dtype.
sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
  sox_encoding_t encoding = SOX_ENCODING_UNKNOWN;
  unsigned bits_per_sample = 0;
  switch (dtype.toScalarType()) {
    case c10::ScalarType::Byte:
      encoding = SOX_ENCODING_UNSIGNED;
      bits_per_sample = 8;
      break;
    case c10::ScalarType::Short:
      encoding = SOX_ENCODING_SIGN2;
      bits_per_sample = 16;
      break;
    case c10::ScalarType::Int:
      encoding = SOX_ENCODING_SIGN2;
      bits_per_sample = 32;
      break;
    case c10::ScalarType::Float:
      encoding = SOX_ENCODING_FLOAT;
      bits_per_sample = 32;
      break;
    default:
      TORCH_CHECK(false, "Unsupported dtype: ", dtype);
  }
  return sox_encodinginfo_t{
      /*encoding=*/encoding,
      /*bits_per_sample=*/bits_per_sample,
      /*compression=*/HUGE_VAL,
      /*reverse_bytes=*/sox_option_default,
      /*reverse_nibbles=*/sox_option_default,
      /*reverse_bits=*/sox_option_default,
      /*opposite_endian=*/sox_false};
}
476

477
// Builds the sox_encodinginfo_t used when saving to `format`.
//   - encoding and bits_per_sample come from get_save_encoding, which also
//     validates the optional `encoding` / `bits_per_sample` overrides.
//   - compression is `compression` if given, otherwise HUGE_VAL.
//   - the reverse_* options are sox_option_default.
//   - opposite_endian is sox_false.
sox_encodinginfo_t get_encodinginfo_for_save(
    const std::string& format,
    const caffe2::TypeMeta& dtype,
    const c10::optional<double>& compression,
    const c10::optional<std::string>& encoding,
    const c10::optional<int64_t>& bits_per_sample) {
  const auto [encoding_kind, bit_width] =
      get_save_encoding(format, dtype, encoding, bits_per_sample);
  const double compression_level = compression.value_or(HUGE_VAL);
  return sox_encodinginfo_t{
      /*encoding=*/encoding_kind,
      /*bits_per_sample=*/bit_width,
      /*compression=*/compression_level,
      /*reverse_bytes=*/sox_option_default,
      /*reverse_nibbles=*/sox_option_default,
      /*reverse_bits=*/sox_option_default,
      /*opposite_endian=*/sox_false};
}

Moto Hira's avatar
Moto Hira committed
494
} // namespace torchaudio::sox