pybind.cpp 16.4 KB
Newer Older
1
2
3
#include <libtorio/ffmpeg/hw_context.h>
#include <libtorio/ffmpeg/stream_reader/stream_reader.h>
#include <libtorio/ffmpeg/stream_writer/stream_writer.h>
4
5
#include <torch/extension.h>

6
namespace torchaudio::io {
7
8
namespace {

9
10
11
std::map<std::string, std::tuple<int64_t, int64_t, int64_t>> get_versions() {
  std::map<std::string, std::tuple<int64_t, int64_t, int64_t>> ret;

12
13
14
15
16
17
18
19
20
#define add_version(NAME)            \
  {                                  \
    int ver = NAME##_version();      \
    ret.emplace(                     \
        "lib" #NAME,                 \
        std::make_tuple<>(           \
            AV_VERSION_MAJOR(ver),   \
            AV_VERSION_MINOR(ver),   \
            AV_VERSION_MICRO(ver))); \
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
  }

  add_version(avutil);
  add_version(avcodec);
  add_version(avformat);
  add_version(avfilter);
  add_version(avdevice);
  return ret;

#undef add_version
}

std::map<std::string, std::string> get_demuxers(bool req_device) {
  std::map<std::string, std::string> ret;
  const AVInputFormat* fmt = nullptr;
  void* i = nullptr;
37
  while ((fmt = av_demuxer_iterate(&i))) {
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
    assert(fmt);
    bool is_device = [&]() {
      const AVClass* avclass = fmt->priv_class;
      return avclass && AV_IS_INPUT_DEVICE(avclass->category);
    }();
    if (req_device == is_device) {
      ret.emplace(fmt->name, fmt->long_name);
    }
  }
  return ret;
}

std::map<std::string, std::string> get_muxers(bool req_device) {
  std::map<std::string, std::string> ret;
  const AVOutputFormat* fmt = nullptr;
  void* i = nullptr;
54
  while ((fmt = av_muxer_iterate(&i))) {
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
    assert(fmt);
    bool is_device = [&]() {
      const AVClass* avclass = fmt->priv_class;
      return avclass && AV_IS_OUTPUT_DEVICE(avclass->category);
    }();
    if (req_device == is_device) {
      ret.emplace(fmt->name, fmt->long_name);
    }
  }
  return ret;
}

// List the codecs enumerated by `av_codec_iterate` that match the given media
// type and role.
//
// type: media type the codec must have (e.g. AVMEDIA_TYPE_AUDIO).
// req_encoder: true to list encoders, false to list decoders.
// Returns a map from codec name to its long name (empty string when the codec
// has no long name). Codecs without a name are skipped.
std::map<std::string, std::string> get_codecs(
    AVMediaType type,
    bool req_encoder) {
  std::map<std::string, std::string> ret;
  void* iter_state = nullptr;
  while (const AVCodec* c = av_codec_iterate(&iter_state)) {
    const bool role_matches =
        req_encoder ? av_codec_is_encoder(c) : av_codec_is_decoder(c);
    if (role_matches && c->type == type && c->name) {
      ret.emplace(c->name, c->long_name ? c->long_name : "");
    }
  }
  return ret;
}

std::vector<std::string> get_protocols(bool output) {
  void* opaque = nullptr;
  const char* name = nullptr;
  std::vector<std::string> ret;
89
  while ((name = avio_enum_protocols(&opaque, output))) {
90
91
92
93
94
95
96
    assert(name);
    ret.emplace_back(name);
  }
  return ret;
}

std::string get_build_config() {
97
  return avcodec_configuration();
98
99
}

100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
//////////////////////////////////////////////////////////////////////////////
// StreamReader/Writer FileObj
//////////////////////////////////////////////////////////////////////////////

// State shared with the AVIO callbacks (read_func / write_func / seek_func)
// through their `opaque` pointer.
struct FileObj {
  // Python object whose `read` / `write` / `seek` attributes are invoked by
  // the callbacks.
  py::object fileobj;
  // Upper bound on the number of bytes the read and write callbacks handle
  // in one call.
  int buffer_size;
};

namespace {

// AVIO read callback backed by a Python file-like object.
//
// opaque: pointer to the FileObj holding the Python object.
// buf: destination buffer.
// buf_size: capacity of `buf`; the amount read is additionally capped at
//   FileObj::buffer_size.
// Returns the number of bytes copied into `buf`, or AVERROR_EOF when no byte
// could be read. Raises (via TORCH_CHECK) if the object's `read` returns more
// bytes than requested.
static int read_func(void* opaque, uint8_t* buf, int buf_size) {
  FileObj* fileobj = static_cast<FileObj*>(opaque);
  buf_size = FFMIN(buf_size, fileobj->buffer_size);

  int num_read = 0;
  // `read(n)` may return fewer than n bytes, so keep asking until the buffer
  // is full or the object reports end of data with an empty result.
  while (num_read < buf_size) {
    const int request = buf_size - num_read;
    const auto chunk = static_cast<std::string>(
        static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
    const size_t chunk_len = chunk.length();
    if (chunk_len == 0) {
      break;
    }
    // `request` is positive here; compare in the unsigned domain explicitly
    // instead of relying on an implicit signed-to-unsigned conversion.
    TORCH_CHECK(
        chunk_len <= static_cast<size_t>(request),
        "Requested up to ",
        request,
        " bytes, but received ",
        chunk_len,
        " bytes. The given object does not conform to the read protocol of file objects.");
    memcpy(buf, chunk.data(), chunk_len);
    buf += chunk_len;
    num_read += static_cast<int>(chunk_len);
  }
  return num_read == 0 ? AVERROR_EOF : num_read;
}

static int write_func(void* opaque, uint8_t* buf, int buf_size) {
  FileObj* fileobj = static_cast<FileObj*>(opaque);
  buf_size = FFMIN(buf_size, fileobj->buffer_size);

  py::bytes b(reinterpret_cast<const char*>(buf), buf_size);
  // TODO: check the return value
  fileobj->fileobj.attr("write")(b);
  return buf_size;
}

// AVIO seek callback backed by a Python file-like object.
//
// Forwards `offset` and `whence` to the object's `seek` attribute and returns
// its result. Size queries (AVSEEK_SIZE) are answered with AVERROR(EIO)
// because the file size is not known.
static int64_t seek_func(void* opaque, int64_t offset, int whence) {
  if (whence != AVSEEK_SIZE) {
    auto* target = static_cast<FileObj*>(opaque);
    return py::cast<int64_t>(target->fileobj.attr("seek")(offset, whence));
  }
  // We do not know the file size.
  return AVERROR(EIO);
}

} // namespace

struct StreamReaderFileObj : private FileObj, public StreamReaderCustomIO {
moto's avatar
moto committed
160
161
162
163
  StreamReaderFileObj(
      py::object fileobj,
      const c10::optional<std::string>& format,
      const c10::optional<std::map<std::string, std::string>>& option,
164
165
166
167
168
169
170
171
172
      int buffer_size)
      : FileObj{fileobj, buffer_size},
        StreamReaderCustomIO(
            this,
            format,
            buffer_size,
            read_func,
            py::hasattr(fileobj, "seek") ? &seek_func : nullptr,
            option) {}
moto's avatar
moto committed
173
174
};

175
struct StreamWriterFileObj : private FileObj, public StreamWriterCustomIO {
moto's avatar
moto committed
176
177
178
  StreamWriterFileObj(
      py::object fileobj,
      const c10::optional<std::string>& format,
179
180
181
182
183
184
185
186
      int buffer_size)
      : FileObj{fileobj, buffer_size},
        StreamWriterCustomIO(
            this,
            format,
            buffer_size,
            write_func,
            py::hasattr(fileobj, "seek") ? &seek_func : nullptr) {}
moto's avatar
moto committed
187
188
};

189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
//////////////////////////////////////////////////////////////////////////////
// StreamReader/Writer Bytes
//////////////////////////////////////////////////////////////////////////////
// State shared with the in-memory AVIO callbacks (read_bytes / seek_bytes)
// through their `opaque` pointer.
struct BytesWrapper {
  // Non-owning view of the source bytes.
  // NOTE(review): the viewed buffer must outlive this object — confirm the
  // caller keeps the underlying bytes alive.
  std::string_view src;
  // Current read position within `src`, advanced by read_bytes and
  // repositioned by seek_bytes.
  size_t index = 0;
};

// AVIO read callback that serves data from an in-memory byte buffer.
//
// opaque: pointer to the BytesWrapper holding the source view and cursor.
// buf: destination buffer.
// buf_size: maximum number of bytes to copy into `buf`.
// Returns the number of bytes copied, or AVERROR_EOF when nothing can be
// read (cursor at or past the end, or a non-positive `buf_size`).
static int read_bytes(void* opaque, uint8_t* buf, int buf_size) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);

  const size_t total = wrapper->src.size();
  // The cursor can sit at or beyond the end after a seek. Test for that
  // before subtracting: `total - index` is unsigned and would wrap around to
  // a huge value, turning the memcpy below into an out-of-bounds read.
  if (buf_size <= 0 || wrapper->index >= total) {
    return AVERROR_EOF;
  }
  const size_t num_read =
      FFMIN(total - wrapper->index, static_cast<size_t>(buf_size));
  memcpy(buf, wrapper->src.data() + wrapper->index, num_read);
  wrapper->index += num_read;
  // num_read <= buf_size, so it always fits in an int.
  return static_cast<int>(num_read);
}

static int64_t seek_bytes(void* opaque, int64_t offset, int whence) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);
  if (whence == AVSEEK_SIZE) {
    return wrapper->src.size();
  }

  if (whence == SEEK_SET) {
    wrapper->index = offset;
  } else if (whence == SEEK_CUR) {
    wrapper->index += offset;
  } else if (whence == SEEK_END) {
    wrapper->index = wrapper->src.size() + offset;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected whence value: ", whence);
  }
  return static_cast<int64_t>(wrapper->index);
}

// StreamReader that decodes from an in-memory byte buffer.
//
// BytesWrapper is listed as the first base so that it is initialized before
// StreamReaderCustomIO receives a pointer to it.
struct StreamReaderBytes : private BytesWrapper, public StreamReaderCustomIO {
  // src: non-owning view of the encoded bytes.
  //   NOTE(review): the viewed buffer must stay alive for as long as this
  //   reader is used — confirm against the Python caller.
  // format / option / buffer_size: forwarded to StreamReaderCustomIO.
  StreamReaderBytes(
      std::string_view src,
      const c10::optional<std::string>& format,
      const c10::optional<std::map<std::string, std::string>>& option,
      int64_t buffer_size)
      : BytesWrapper{src},
        StreamReaderCustomIO(
            // The callbacks cast the opaque pointer back to BytesWrapper*, so
            // hand over the address of the BytesWrapper subobject explicitly
            // rather than relying on it sharing the address of `this`.
            static_cast<BytesWrapper*>(this),
            format,
            buffer_size,
            read_bytes,
            seek_bytes,
            option) {}
};

244
245
246
247
248
#ifndef TORCHAUDIO_FFMPEG_EXT_NAME
#error TORCHAUDIO_FFMPEG_EXT_NAME must be defined.
#endif

// Python extension module definition. Registers the FFmpeg utility functions
// and the StreamReader / StreamWriter class families. Every class is
// registered with py::module_local().
PYBIND11_MODULE(TORCHAUDIO_FFMPEG_EXT_NAME, m) {
  // ---- Library setup and logging -----------------------------------------
  m.def("init", []() { avdevice_register_all(); });
  m.def("get_log_level", []() { return av_log_get_level(); });
  m.def("set_log_level", [](int level) { av_log_set_level(level); });

  // ---- Introspection of the FFmpeg libraries -----------------------------
  m.def("get_versions", &get_versions);
  m.def("get_muxers", []() { return get_muxers(false); });
  m.def("get_demuxers", []() { return get_demuxers(false); });
  m.def("get_input_devices", []() { return get_demuxers(true); });
  m.def("get_build_config", &get_build_config);
  m.def("get_output_devices", []() { return get_muxers(true); });
  m.def("get_audio_decoders", []() {
    return get_codecs(AVMEDIA_TYPE_AUDIO, false);
  });
  m.def("get_audio_encoders", []() {
    return get_codecs(AVMEDIA_TYPE_AUDIO, true);
  });
  m.def("get_video_decoders", []() {
    return get_codecs(AVMEDIA_TYPE_VIDEO, false);
  });
  m.def("get_video_encoders", []() {
    return get_codecs(AVMEDIA_TYPE_VIDEO, true);
  });
  m.def("get_input_protocols", []() { return get_protocols(false); });
  m.def("get_output_protocols", []() { return get_protocols(true); });
  m.def("clear_cuda_context_cache", &clear_cuda_context_cache);

  // ---- Plain data types --------------------------------------------------
  py::class_<Chunk>(m, "Chunk", py::module_local())
      .def_readwrite("frames", &Chunk::frames)
      .def_readwrite("pts", &Chunk::pts);
  py::class_<CodecConfig>(m, "CodecConfig", py::module_local())
      .def(py::init<int, int, const c10::optional<int>&, int, int>());

  // ---- StreamWriter (destination given as a string) ----------------------
  py::class_<StreamWriter>(m, "StreamWriter", py::module_local())
      .def(py::init<const std::string&, const c10::optional<std::string>&>())
      .def("set_metadata", &StreamWriter::set_metadata)
      .def("add_audio_stream", &StreamWriter::add_audio_stream)
      .def("add_video_stream", &StreamWriter::add_video_stream)
      .def("dump_format", &StreamWriter::dump_format)
      .def("open", &StreamWriter::open)
      .def("write_audio_chunk", &StreamWriter::write_audio_chunk)
      .def("write_video_chunk", &StreamWriter::write_video_chunk)
      .def("flush", &StreamWriter::flush)
      .def("close", &StreamWriter::close);

  // ---- StreamWriter (destination is a Python file-like object) -----------
  py::class_<StreamWriterFileObj>(m, "StreamWriterFileObj", py::module_local())
      .def(py::init<py::object, const c10::optional<std::string>&, int64_t>())
      .def("set_metadata", &StreamWriterFileObj::set_metadata)
      .def("add_audio_stream", &StreamWriterFileObj::add_audio_stream)
      .def("add_video_stream", &StreamWriterFileObj::add_video_stream)
      .def("dump_format", &StreamWriterFileObj::dump_format)
      .def("open", &StreamWriterFileObj::open)
      .def("write_audio_chunk", &StreamWriterFileObj::write_audio_chunk)
      .def("write_video_chunk", &StreamWriterFileObj::write_video_chunk)
      .def("flush", &StreamWriterFileObj::flush)
      .def("close", &StreamWriterFileObj::close);

  // ---- Output stream description -----------------------------------------
  py::class_<OutputStreamInfo>(m, "OutputStreamInfo", py::module_local())
      .def_readonly("source_index", &OutputStreamInfo::source_index)
      .def_readonly("filter_description", &OutputStreamInfo::filter_description)
      .def_property_readonly(
          "media_type",
          [](const OutputStreamInfo& o) -> std::string {
            return av_get_media_type_string(o.media_type);
          })
      .def_property_readonly(
          "format",
          [](const OutputStreamInfo& o) -> std::string {
            // `format` holds a sample format for audio and a pixel format
            // for video, so the name lookup depends on the media type.
            switch (o.media_type) {
              case AVMEDIA_TYPE_AUDIO:
                return av_get_sample_fmt_name((AVSampleFormat)(o.format));
              case AVMEDIA_TYPE_VIDEO:
                return av_get_pix_fmt_name((AVPixelFormat)(o.format));
              default:
                TORCH_INTERNAL_ASSERT(
                    false,
                    "FilterGraph is returning unexpected media type: ",
                    av_get_media_type_string(o.media_type));
            }
          })
      .def_readonly("sample_rate", &OutputStreamInfo::sample_rate)
      .def_readonly("num_channels", &OutputStreamInfo::num_channels)
      .def_readonly("width", &OutputStreamInfo::width)
      .def_readonly("height", &OutputStreamInfo::height)
      .def_property_readonly(
          "frame_rate", [](const OutputStreamInfo& o) -> double {
            // A zero denominator cannot be converted to a rate; warn and
            // report -1 instead of dividing by zero.
            if (o.frame_rate.den == 0) {
              TORCH_WARN(
                  "Invalid frame rate is found: ",
                  o.frame_rate.num,
                  "/",
                  o.frame_rate.den);
              return -1;
            }
            return static_cast<double>(o.frame_rate.num) / o.frame_rate.den;
          });

  // ---- Source stream description -----------------------------------------
  py::class_<SrcStreamInfo>(m, "SourceStreamInfo", py::module_local())
      .def_property_readonly(
          "media_type",
          [](const SrcStreamInfo& s) {
            return av_get_media_type_string(s.media_type);
          })
      .def_readonly("codec_name", &SrcStreamInfo::codec_name)
      .def_readonly("codec_long_name", &SrcStreamInfo::codec_long_name)
      .def_readonly("format", &SrcStreamInfo::fmt_name)
      .def_readonly("bit_rate", &SrcStreamInfo::bit_rate)
      .def_readonly("num_frames", &SrcStreamInfo::num_frames)
      .def_readonly("bits_per_sample", &SrcStreamInfo::bits_per_sample)
      .def_readonly("metadata", &SrcStreamInfo::metadata)
      .def_readonly("sample_rate", &SrcStreamInfo::sample_rate)
      .def_readonly("num_channels", &SrcStreamInfo::num_channels)
      .def_readonly("width", &SrcStreamInfo::width)
      .def_readonly("height", &SrcStreamInfo::height)
      .def_readonly("frame_rate", &SrcStreamInfo::frame_rate);

  // ---- StreamReader (source given as a string) ---------------------------
  py::class_<StreamReader>(m, "StreamReader", py::module_local())
      .def(py::init<
           const std::string&,
           const c10::optional<std::string>&,
           const c10::optional<OptionDict>&>())
      .def("num_src_streams", &StreamReader::num_src_streams)
      .def("num_out_streams", &StreamReader::num_out_streams)
      .def("find_best_audio_stream", &StreamReader::find_best_audio_stream)
      .def("find_best_video_stream", &StreamReader::find_best_video_stream)
      .def("get_metadata", &StreamReader::get_metadata)
      .def("get_src_stream_info", &StreamReader::get_src_stream_info)
      .def("get_out_stream_info", &StreamReader::get_out_stream_info)
      .def("seek", &StreamReader::seek)
      .def("add_audio_stream", &StreamReader::add_audio_stream)
      .def("add_video_stream", &StreamReader::add_video_stream)
      .def("remove_stream", &StreamReader::remove_stream)
      .def(
          "process_packet",
          // overload_cast selects the (timeout, backoff) overload.
          py::overload_cast<const c10::optional<double>&, const double>(
              &StreamReader::process_packet))
      .def("process_all_packets", &StreamReader::process_all_packets)
      .def("fill_buffer", &StreamReader::fill_buffer)
      .def("is_buffer_ready", &StreamReader::is_buffer_ready)
      .def("pop_chunks", &StreamReader::pop_chunks);

  // ---- StreamReader (source is a Python file-like object) ----------------
  py::class_<StreamReaderFileObj>(m, "StreamReaderFileObj", py::module_local())
      .def(py::init<
           py::object,
           const c10::optional<std::string>&,
           const c10::optional<OptionDict>&,
           int64_t>())
      .def("num_src_streams", &StreamReaderFileObj::num_src_streams)
      .def("num_out_streams", &StreamReaderFileObj::num_out_streams)
      .def(
          "find_best_audio_stream",
          &StreamReaderFileObj::find_best_audio_stream)
      .def(
          "find_best_video_stream",
          &StreamReaderFileObj::find_best_video_stream)
      .def("get_metadata", &StreamReaderFileObj::get_metadata)
      .def("get_src_stream_info", &StreamReaderFileObj::get_src_stream_info)
      .def("get_out_stream_info", &StreamReaderFileObj::get_out_stream_info)
      .def("seek", &StreamReaderFileObj::seek)
      .def("add_audio_stream", &StreamReaderFileObj::add_audio_stream)
      .def("add_video_stream", &StreamReaderFileObj::add_video_stream)
      .def("remove_stream", &StreamReaderFileObj::remove_stream)
      .def(
          "process_packet",
          py::overload_cast<const c10::optional<double>&, const double>(
              &StreamReader::process_packet))
      .def("process_all_packets", &StreamReaderFileObj::process_all_packets)
      .def("fill_buffer", &StreamReaderFileObj::fill_buffer)
      .def("is_buffer_ready", &StreamReaderFileObj::is_buffer_ready)
      .def("pop_chunks", &StreamReaderFileObj::pop_chunks);

  // ---- StreamReader (source is an in-memory byte buffer) -----------------
  py::class_<StreamReaderBytes>(m, "StreamReaderBytes", py::module_local())
      .def(py::init<
           std::string_view,
           const c10::optional<std::string>&,
           const c10::optional<OptionDict>&,
           int64_t>())
      .def("num_src_streams", &StreamReaderBytes::num_src_streams)
      .def("num_out_streams", &StreamReaderBytes::num_out_streams)
      .def("find_best_audio_stream", &StreamReaderBytes::find_best_audio_stream)
      .def("find_best_video_stream", &StreamReaderBytes::find_best_video_stream)
      .def("get_metadata", &StreamReaderBytes::get_metadata)
      .def("get_src_stream_info", &StreamReaderBytes::get_src_stream_info)
      .def("get_out_stream_info", &StreamReaderBytes::get_out_stream_info)
      .def("seek", &StreamReaderBytes::seek)
      .def("add_audio_stream", &StreamReaderBytes::add_audio_stream)
      .def("add_video_stream", &StreamReaderBytes::add_video_stream)
      .def("remove_stream", &StreamReaderBytes::remove_stream)
      .def(
          "process_packet",
          py::overload_cast<const c10::optional<double>&, const double>(
              &StreamReader::process_packet))
      .def("process_all_packets", &StreamReaderBytes::process_all_packets)
      .def("fill_buffer", &StreamReaderBytes::fill_buffer)
      .def("is_buffer_ready", &StreamReaderBytes::is_buffer_ready)
      .def("pop_chunks", &StreamReaderBytes::pop_chunks);
}

} // namespace
439
} // namespace torchaudio::io