You need to sign in or sign up before continuing.
io.cpp 4.2 KB
Newer Older
1
2
3
4
5
#include <libtorchaudio/sox/effects.h>
#include <libtorchaudio/sox/effects_chain.h>
#include <libtorchaudio/sox/io.h>
#include <libtorchaudio/sox/types.h>
#include <libtorchaudio/sox/utils.h>
6
7
8

using namespace torch::indexing;

Moto Hira's avatar
Moto Hira committed
9
namespace torchaudio::sox {
10

11
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file(
12
    const std::string& path,
13
    const std::optional<std::string>& format) {
moto's avatar
moto committed
14
15
16
17
  SoxFormat sf(sox_open_read(
      path.c_str(),
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
18
      /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
moto's avatar
moto committed
19

20
  validate_input_file(sf, path);
21

22
  return std::make_tuple(
moto's avatar
moto committed
23
      static_cast<int64_t>(sf->signal.rate),
24
      static_cast<int64_t>(sf->signal.length / sf->signal.channels),
25
      static_cast<int64_t>(sf->signal.channels),
26
27
      static_cast<int64_t>(sf->encoding.bits_per_sample),
      get_encoding(sf->encoding.encoding));
moto's avatar
moto committed
28
}
29

30
std::vector<std::vector<std::string>> get_effects(
31
32
    const std::optional<int64_t>& frame_offset,
    const std::optional<int64_t>& num_frames) {
33
  const auto offset = frame_offset.value_or(0);
34
35
36
37
  TORCH_CHECK(
      offset >= 0,
      "Invalid argument: frame_offset must be non-negative. Found: ",
      offset);
38
  const auto frames = num_frames.value_or(-1);
39
40
41
  TORCH_CHECK(
      frames > 0 || frames == -1,
      "Invalid argument: num_frames must be -1 or greater than 0.");
42

43
  std::vector<std::vector<std::string>> effects;
44
45
46
47
  if (frames != -1) {
    std::ostringstream os_offset, os_frames;
    os_offset << offset << "s";
    os_frames << "+" << frames << "s";
48
    effects.emplace_back(
49
50
51
52
53
        std::vector<std::string>{"trim", os_offset.str(), os_frames.str()});
  } else if (offset != 0) {
    std::ostringstream os_offset;
    os_offset << offset << "s";
    effects.emplace_back(std::vector<std::string>{"trim", os_offset.str()});
moto's avatar
moto committed
54
  }
55
56
57
  return effects;
}

58
std::tuple<torch::Tensor, int64_t> load_audio_file(
59
    const std::string& path,
60
61
62
63
64
    const std::optional<int64_t>& frame_offset,
    const std::optional<int64_t>& num_frames,
    std::optional<bool> normalize,
    std::optional<bool> channels_first,
    const std::optional<std::string>& format) {
65
  auto effects = get_effects(frame_offset, num_frames);
Moto Hira's avatar
Moto Hira committed
66
  return apply_effects_file(path, effects, normalize, channels_first, format);
67
68
}

69
void save_audio_file(
70
71
72
73
    const std::string& path,
    torch::Tensor tensor,
    int64_t sample_rate,
    bool channels_first,
74
75
76
77
    std::optional<double> compression,
    std::optional<std::string> format,
    std::optional<std::string> encoding,
    std::optional<int64_t> bits_per_sample) {
78
79
  validate_input_tensor(tensor);

moto's avatar
moto committed
80
  const auto filetype = [&]() {
moto-meta's avatar
moto-meta committed
81
    if (format.has_value()) {
moto's avatar
moto committed
82
      return format.value();
moto-meta's avatar
moto-meta committed
83
    }
84
85
    return get_filetype(path);
  }();
86

87
  if (filetype == "amr-nb") {
88
    const auto num_channels = tensor.size(channels_first ? 0 : 1);
89
90
    TORCH_CHECK(
        num_channels == 1, "amr-nb format only supports single channel audio.");
91
92
93
94
  } else if (filetype == "htk") {
    const auto num_channels = tensor.size(channels_first ? 0 : 1);
    TORCH_CHECK(
        num_channels == 1, "htk format only supports single channel audio.");
95
96
97
98
99
100
101
  } else if (filetype == "gsm") {
    const auto num_channels = tensor.size(channels_first ? 0 : 1);
    TORCH_CHECK(
        num_channels == 1, "gsm format only supports single channel audio.");
    TORCH_CHECK(
        sample_rate == 8000,
        "gsm format only supports a sampling rate of 8kHz.");
102
  }
103
104
  const auto signal_info =
      get_signalinfo(&tensor, sample_rate, filetype, channels_first);
105
106
  const auto encoding_info = get_encodinginfo_for_save(
      filetype, tensor.dtype(), compression, encoding, bits_per_sample);
107
108

  SoxFormat sf(sox_open_write(
109
      path.c_str(),
110
111
112
113
114
115
      &signal_info,
      &encoding_info,
      /*filetype=*/filetype.c_str(),
      /*oob=*/nullptr,
      /*overwrite_permitted=*/nullptr));

116
117
118
119
  TORCH_CHECK(
      static_cast<sox_format_t*>(sf) != nullptr,
      "Error saving audio file: failed to open file ",
      path);
120

Moto Hira's avatar
Moto Hira committed
121
  SoxEffectsChain chain(
122
      /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
123
      /*output_encoding=*/sf->encoding);
124
  chain.addInputTensor(&tensor, sample_rate, channels_first);
125
126
  chain.addOutputFile(sf);
  chain.run();
127
}
Moto Hira's avatar
Moto Hira committed
128
} // namespace torchaudio::sox