#include "video_reader.h"

#include <Python.h>

#include <chrono>
#include <string>
#include <vector>

#include "../decoder/memory_buffer.h"
#include "../decoder/sync_decoder.h"

// If we are in a Windows environment, we need to define an
// initialization function for the video_reader extension.
#ifdef _WIN32
PyMODINIT_FUNC PyInit_video_reader(void) {
  // No need to do anything.
  return NULL;
}
#endif

using namespace ffmpeg;

namespace vision {
namespace video_reader {

namespace {

const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
const AVSampleFormat defaultAudioSampleFormat = AV_SAMPLE_FMT_FLT;
const AVRational timeBaseQ = AVRational{1, AV_TIME_BASE};
const size_t decoderTimeoutMs = 600000;
// A jitter can be added to the end of the range to avoid conversion/rounding
// errors. A small value of 100 us is not enough to select the next frame, but
// is enough to compensate for rounding errors introduced by the multiple
// conversions.
const size_t timeBaseJitterUs = 100;

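// Builds a DecoderParameters struct from the flat argument list. Video and
// audio MediaFormat entries are added only for the streams that were
// requested; start/end offsets are expected in microseconds (AV_TIME_BASE
// units).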
DecoderParameters getDecoderParams(
    int64_t videoStartUs,
    int64_t videoEndUs,
    double seekFrameMarginUs,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int videoWidth,
    int videoHeight,
    int videoMinDimension,
    int videoMaxDimension,
    int64_t readAudioStream,
    int audioSamples,
    int audioChannels) {
  DecoderParameters params;
  params.headerOnly = getPtsOnly != 0;
  params.seekAccuracy = seekFrameMarginUs;
  params.startOffset = videoStartUs;
  params.endOffset = videoEndUs;
  params.timeoutMs = decoderTimeoutMs;
  params.preventStaleness = false;

  if (readVideoStream == 1) {
    MediaFormat videoFormat(0);
    videoFormat.type = TYPE_VIDEO;
    videoFormat.format.video.format = defaultVideoPixelFormat;
    videoFormat.format.video.width = videoWidth;
    videoFormat.format.video.height = videoHeight;
    videoFormat.format.video.minDimension = videoMinDimension;
    videoFormat.format.video.maxDimension = videoMaxDimension;
    params.formats.insert(videoFormat);
  }

  if (readAudioStream == 1) {
    MediaFormat audioFormat;
    audioFormat.type = TYPE_AUDIO;
    audioFormat.format.audio.format = defaultAudioSampleFormat;
    audioFormat.format.audio.samples = audioSamples;
    audioFormat.format.audio.channels = audioChannels;
    params.formats.insert(audioFormat);
  }

  return params;
}

// Returns the number of bytes written into `frame`.
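// Copies each decoded message payload into the pre-allocated `frame` tensor
// (when it is non-empty) and rescales each pts from AV_TIME_BASE units back
// into the stream's original time base given by num/den.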
template <typename T>
size_t fillTensor(
    std::vector<DecoderOutputMessage>& msgs,
    torch::Tensor& frame,
    torch::Tensor& framePts,
    int64_t num,
    int64_t den) {
  if (msgs.empty()) {
    return 0;
  }
  T* frameData = frame.numel() > 0 ? frame.data_ptr<T>() : nullptr;
  int64_t* framePtsData = framePts.data_ptr<int64_t>();
  CHECK_EQ(framePts.size(0), (int64_t)msgs.size());
  size_t avgElementsInFrame = frame.numel() / msgs.size();

  size_t offset = 0;
  for (size_t i = 0; i < msgs.size(); ++i) {
    const auto& msg = msgs[i];
    // convert pts into original time_base
    AVRational avr = AVRational{(int)num, (int)den};
    framePtsData[i] = av_rescale_q(msg.header.pts, timeBaseQ, avr);
    VLOG(2) << "PTS type: " << sizeof(T) << ", us: " << msg.header.pts
            << ", original: " << framePtsData[i];

    if (frameData) {
      auto sizeInBytes = msg.payload->length();
      memcpy(frameData + offset, msg.payload->data(), sizeInBytes);
      if (sizeof(T) == sizeof(uint8_t)) {
        // Video - move by allocated frame size
        offset += avgElementsInFrame / sizeof(T);
      } else {
        // Audio - move by number of samples
        offset += sizeInBytes / sizeof(T);
      }
    }
  }
  return offset * sizeof(T);
}

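// Video frames are packed uint8 RGB24 (see defaultVideoPixelFormat), so one
// tensor element is one byte.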
size_t fillVideoTensor(
    std::vector<DecoderOutputMessage>& msgs,
    torch::Tensor& videoFrame,
    torch::Tensor& videoFramePts,
    int64_t num,
    int64_t den) {
  return fillTensor<uint8_t>(msgs, videoFrame, videoFramePts, num, den);
}

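// Audio samples are interleaved 32-bit floats (see defaultAudioSampleFormat).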
size_t fillAudioTensor(
    std::vector<DecoderOutputMessage>& msgs,
    torch::Tensor& audioFrame,
    torch::Tensor& audioFramePts,
    int64_t num,
    int64_t den) {
  return fillTensor<float>(msgs, audioFrame, audioFramePts, num, den);
}

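// Converts the requested start/end pts, given in the stream's own time base,
// into microseconds (AV_TIME_BASE units). The video stream defines the range
// when it is read; otherwise the audio stream does. Either way the result is
// returned through videoStartUs/videoEndUs, with videoEndUs staying -1 when
// no explicit end is given. seekFrameMargin is scaled from seconds to
// microseconds in place.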
void offsetsToUs(
    double& seekFrameMargin,
    int64_t readVideoStream,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen,
    int64_t& videoStartUs,
    int64_t& videoEndUs) {
  seekFrameMargin *= AV_TIME_BASE;
  videoStartUs = 0;
  videoEndUs = -1;

  if (readVideoStream) {
    AVRational vr = AVRational{(int)videoTimeBaseNum, (int)videoTimeBaseDen};
    if (videoStartPts > 0) {
      videoStartUs = av_rescale_q(videoStartPts, vr, timeBaseQ);
    }
    if (videoEndPts > 0) {
      // Add jitter to the end of the range to avoid conversion/rounding
      // errors. A small value of 100 us is not enough to select the next
      // frame, but is enough to compensate for rounding errors introduced by
      // the multiple conversions.
      videoEndUs = timeBaseJitterUs + av_rescale_q(videoEndPts, vr, timeBaseQ);
    }
  } else if (readAudioStream) {
    AVRational ar = AVRational{(int)audioTimeBaseNum, (int)audioTimeBaseDen};
    if (audioStartPts > 0) {
      videoStartUs = av_rescale_q(audioStartPts, ar, timeBaseQ);
    }
    if (audioEndPts > 0) {
      // Add jitter to the end of the range to avoid conversion/rounding
      // errors. A small value of 100 us is not enough to select the next
      // frame, but is enough to compensate for rounding errors introduced by
      // the multiple conversions.
      videoEndUs = timeBaseJitterUs + av_rescale_q(audioEndPts, ar, timeBaseQ);
    }
  }
}

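// Decodes a video from a file path or an in-memory buffer and returns a list
// of ten tensors: videoFrame, videoFramePts, videoTimeBase, videoFps,
// videoDuration, audioFrame, audioFramePts, audioTimeBase, audioSampleRate,
// audioDuration. Tensors for streams that were not read remain empty.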
torch::List<torch::Tensor> readVideo(
    bool isReadFile,
    const torch::Tensor& input_video,
    std::string videoPath,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  int64_t videoStartUs, videoEndUs;

  offsetsToUs(
      seekFrameMargin,
      readVideoStream,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      readAudioStream,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen,
      videoStartUs,
      videoEndUs);

  DecoderParameters params = getDecoderParams(
      videoStartUs, // videoStartUs
      videoEndUs, // videoEndUs
      seekFrameMargin, // seekFrameMarginUs
      getPtsOnly, // getPtsOnly
      readVideoStream, // readVideoStream
      width, // width
      height, // height
      minDimension, // minDimension
      maxDimension, // maxDimension
      readAudioStream, // readAudioStream
      audioSamples, // audioSamples
      audioChannels // audioChannels
  );

  SyncDecoder decoder;
  std::vector<DecoderOutputMessage> audioMessages, videoMessages;
  DecoderInCallback callback = nullptr;
  std::string logMessage, logType;
  if (isReadFile) {
    params.uri = videoPath;
    logType = "file";
    logMessage = videoPath;
  } else {
    callback = MemoryBuffer::getCallback(
        input_video.data_ptr<uint8_t>(), input_video.size(0));
    logType = "memory";
    logMessage = std::to_string(input_video.size(0));
  }

  VLOG(1) << "Video decoding from " << logType << " [" << logMessage
          << "] has started";

  const auto now = std::chrono::system_clock::now();

  bool succeeded;
  DecoderMetadata audioMetadata, videoMetadata;
  std::vector<DecoderMetadata> metadata;
  if ((succeeded = decoder.init(params, std::move(callback), &metadata))) {
    for (const auto& header : metadata) {
      if (header.format.type == TYPE_VIDEO) {
        videoMetadata = header;
      } else if (header.format.type == TYPE_AUDIO) {
        audioMetadata = header;
      }
    }
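    // Drain the decoder: decode() returns 0 for each delivered message and a
    // nonzero status once the stream ends, times out, or fails.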
    int res;
    DecoderOutputMessage msg;
    while (0 == (res = decoder.decode(&msg, decoderTimeoutMs))) {
      if (msg.header.format.type == TYPE_VIDEO) {
        videoMessages.push_back(std::move(msg));
      }
      if (msg.header.format.type == TYPE_AUDIO) {
        audioMessages.push_back(std::move(msg));
      }
      msg.payload.reset();
    }
  } else {
    LOG(ERROR) << "Decoder initialization has failed";
  }
  const auto then = std::chrono::system_clock::now();
  VLOG(1) << "Video decoding from " << logType << " [" << logMessage
          << "] has finished, "
          << std::chrono::duration_cast<std::chrono::microseconds>(then - now)
                 .count()
          << " us";

  decoder.shutdown();

  // video section
  torch::Tensor videoFrame = torch::zeros({0}, torch::kByte);
  torch::Tensor videoFramePts = torch::zeros({0}, torch::kLong);
  torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
  torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);

  if (succeeded && readVideoStream == 1) {
    if (!videoMessages.empty()) {
      const auto& header = videoMetadata;
      const auto& format = header.format.format.video;
      int numVideoFrames = videoMessages.size();
      int outHeight = format.height;
      int outWidth = format.width;
      int numChannels = 3; // decoder guarantees the default AV_PIX_FMT_RGB24

      size_t expectedWrittenBytes = 0;
      if (getPtsOnly == 0) {
        videoFrame = torch::zeros(
            {numVideoFrames, outHeight, outWidth, numChannels}, torch::kByte);
        expectedWrittenBytes =
            (size_t)numVideoFrames * outHeight * outWidth * numChannels;
      }

      videoFramePts = torch::zeros({numVideoFrames}, torch::kLong);

      VLOG(2) << "video duration: " << header.duration
              << ", fps: " << header.fps << ", num: " << header.num
              << ", den: " << header.den << ", num frames: " << numVideoFrames;

      auto numberWrittenBytes = fillVideoTensor(
          videoMessages, videoFrame, videoFramePts, header.num, header.den);

      CHECK_EQ(numberWrittenBytes, expectedWrittenBytes);

      videoTimeBase = torch::zeros({2}, torch::kInt);
      int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
      videoTimeBaseData[0] = header.num;
      videoTimeBaseData[1] = header.den;

      videoFps = torch::zeros({1}, torch::kFloat);
      float* videoFpsData = videoFps.data_ptr<float>();
      videoFpsData[0] = header.fps;

      videoDuration = torch::zeros({1}, torch::kLong);
      int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
      AVRational vr = AVRational{(int)header.num, (int)header.den};
      videoDurationData[0] = av_rescale_q(header.duration, timeBaseQ, vr);
      VLOG(1) << "Video decoding from " << logType << " [" << logMessage
              << "] filled video tensors";
    } else {
      VLOG(1) << "Missing video stream";
    }
  }

  // audio section
  torch::Tensor audioFrame = torch::zeros({0}, torch::kFloat);
  torch::Tensor audioFramePts = torch::zeros({0}, torch::kLong);
  torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
  torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
  if (succeeded && readAudioStream == 1) {
    if (!audioMessages.empty()) {
      const auto& header = audioMetadata;
      const auto& format = header.format.format.audio;

      int64_t outAudioChannels = format.channels;
      int bytesPerSample =
          av_get_bytes_per_sample(static_cast<AVSampleFormat>(format.format));

      int numAudioFrames = audioMessages.size();
      int64_t numAudioSamples = 0;
      if (getPtsOnly == 0) {
        int64_t frameSizeTotal = 0;
        for (auto const& audioMessage : audioMessages) {
          frameSizeTotal += audioMessage.payload->length();
        }

        CHECK_EQ(frameSizeTotal % (outAudioChannels * bytesPerSample), 0);
        numAudioSamples = frameSizeTotal / (outAudioChannels * bytesPerSample);

        audioFrame =
            torch::zeros({numAudioSamples, outAudioChannels}, torch::kFloat);
      }
      audioFramePts = torch::zeros({numAudioFrames}, torch::kLong);

      VLOG(2) << "audio duration: " << header.duration
              << ", channels: " << format.channels
              << ", sample rate: " << format.samples << ", num: " << header.num
              << ", den: " << header.den;

      auto numberWrittenBytes = fillAudioTensor(
          audioMessages, audioFrame, audioFramePts, header.num, header.den);
      CHECK_EQ(
          numberWrittenBytes,
          numAudioSamples * outAudioChannels * sizeof(float));

      audioTimeBase = torch::zeros({2}, torch::kInt);
      int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
      audioTimeBaseData[0] = header.num;
      audioTimeBaseData[1] = header.den;

      audioSampleRate = torch::zeros({1}, torch::kInt);
      int* audioSampleRateData = audioSampleRate.data_ptr<int>();
      audioSampleRateData[0] = format.samples;

      audioDuration = torch::zeros({1}, torch::kLong);
      int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
      AVRational ar = AVRational{(int)header.num, (int)header.den};
      audioDurationData[0] = av_rescale_q(header.duration, timeBaseQ, ar);
      VLOG(1) << "Video decoding from " << logType << " [" << logMessage
              << "] filled audio tensors";
    } else {
      VLOG(1) << "Missing audio stream";
    }
  }

  torch::List<torch::Tensor> result;
  result.push_back(std::move(videoFrame));
  result.push_back(std::move(videoFramePts));
  result.push_back(std::move(videoTimeBase));
  result.push_back(std::move(videoFps));
  result.push_back(std::move(videoDuration));
  result.push_back(std::move(audioFrame));
  result.push_back(std::move(audioFramePts));
  result.push_back(std::move(audioTimeBase));
  result.push_back(std::move(audioSampleRate));
  result.push_back(std::move(audioDuration));

  VLOG(1) << "Video decoding from " << logType << " [" << logMessage
          << "] is about to return";

  return result;
}

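// Opens the media and reads stream headers only (getPtsOnly = 1), without
// decoding any frames. Returns a list of six tensors: videoTimeBase,
// videoFps, videoDuration, audioTimeBase, audioSampleRate, audioDuration.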
torch::List<torch::Tensor> probeVideo(
    bool isReadFile,
    const torch::Tensor& input_video,
    std::string videoPath) {
  DecoderParameters params = getDecoderParams(
      0, // videoStartUs
      -1, // videoEndUs
      0, // seekFrameMargin
      1, // getPtsOnly
      1, // readVideoStream
      0, // width
      0, // height
      0, // minDimension
      0, // maxDimension
      1, // readAudioStream
      0, // audioSamples
      0 // audioChannels
  );

  SyncDecoder decoder;
  DecoderInCallback callback = nullptr;
  std::string logMessage, logType;
  if (isReadFile) {
    params.uri = videoPath;
    logType = "file";
    logMessage = videoPath;
  } else {
    callback = MemoryBuffer::getCallback(
        input_video.data_ptr<uint8_t>(), input_video.size(0));
    logType = "memory";
    logMessage = std::to_string(input_video.size(0));
  }

  VLOG(1) << "Video probing from " << logType << " [" << logMessage
          << "] has started";

  const auto now = std::chrono::system_clock::now();

  bool succeeded;
  bool gotAudio = false, gotVideo = false;
  DecoderMetadata audioMetadata, videoMetadata;
  std::vector<DecoderMetadata> metadata;
  if ((succeeded = decoder.init(params, std::move(callback), &metadata))) {
    for (const auto& header : metadata) {
      if (header.format.type == TYPE_VIDEO) {
        gotVideo = true;
        videoMetadata = header;
      } else if (header.format.type == TYPE_AUDIO) {
        gotAudio = true;
        audioMetadata = header;
      }
    }
    const auto then = std::chrono::system_clock::now();
    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] has finished, "
            << std::chrono::duration_cast<std::chrono::microseconds>(then - now)
                   .count()
            << " us";
  } else {
    LOG(ERROR) << "Decoder initialization has failed";
  }

  decoder.shutdown();

  // video section
  torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
  torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);

  if (succeeded && gotVideo) {
    videoTimeBase = torch::zeros({2}, torch::kInt);
    int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
    const auto& header = videoMetadata;

    videoTimeBaseData[0] = header.num;
    videoTimeBaseData[1] = header.den;

    videoFps = torch::zeros({1}, torch::kFloat);
    float* videoFpsData = videoFps.data_ptr<float>();
    videoFpsData[0] = header.fps;

    videoDuration = torch::zeros({1}, torch::kLong);
    int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
    AVRational avr = AVRational{(int)header.num, (int)header.den};
    videoDurationData[0] = av_rescale_q(header.duration, timeBaseQ, avr);

    VLOG(2) << "Probed fps: " << header.fps
            << ", duration: " << header.duration << ", num: " << header.num
            << ", den: " << header.den;

    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] filled video tensors";
  } else {
    LOG(ERROR) << "Missing video stream";
  }

  // audio section
  torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
  torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);

  if (succeeded && gotAudio) {
    audioTimeBase = torch::zeros({2}, torch::kInt);
    int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
    const auto& header = audioMetadata;
    const auto& media = header.format;
    const auto& format = media.format.audio;

    audioTimeBaseData[0] = header.num;
    audioTimeBaseData[1] = header.den;

    audioSampleRate = torch::zeros({1}, torch::kInt);
    int* audioSampleRateData = audioSampleRate.data_ptr<int>();
    audioSampleRateData[0] = format.samples;

    audioDuration = torch::zeros({1}, torch::kLong);
    int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
    AVRational avr = AVRational{(int)header.num, (int)header.den};
    audioDurationData[0] = av_rescale_q(header.duration, timeBaseQ, avr);

    VLOG(2) << "Probed sample rate: " << format.samples
            << ", duration: " << header.duration << ", num: " << header.num
            << ", den: " << header.den;

    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] filled audio tensors";
  } else {
    VLOG(1) << "Missing audio stream";
  }

  torch::List<torch::Tensor> result;
  result.push_back(std::move(videoTimeBase));
  result.push_back(std::move(videoFps));
  result.push_back(std::move(videoDuration));
  result.push_back(std::move(audioTimeBase));
  result.push_back(std::move(audioSampleRate));
  result.push_back(std::move(audioDuration));

  VLOG(1) << "Video probing from " << logType << " [" << logMessage
          << "] is about to return";

  return result;
}

} // namespace

torch::List<torch::Tensor> read_video_from_memory(
    torch::Tensor input_video,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.read_video_from_memory");
  return readVideo(
      false,
      input_video,
      "", // videoPath
      seekFrameMargin,
      getPtsOnly,
      readVideoStream,
      width,
      height,
      minDimension,
      maxDimension,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      readAudioStream,
      audioSamples,
      audioChannels,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen);
}

torch::List<torch::Tensor> read_video_from_file(
    std::string videoPath,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.read_video_from_file");
  torch::Tensor dummy_input_video = torch::ones({0});
  return readVideo(
      true,
      dummy_input_video,
      videoPath,
      seekFrameMargin,
      getPtsOnly,
      readVideoStream,
      width,
      height,
      minDimension,
      maxDimension,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      readAudioStream,
      audioSamples,
      audioChannels,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen);
}

torch::List<torch::Tensor> probe_video_from_memory(torch::Tensor input_video) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.probe_video_from_memory");
  return probeVideo(false, input_video, "");
}

torch::List<torch::Tensor> probe_video_from_file(std::string videoPath) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.probe_video_from_file");
  torch::Tensor dummy_input_video = torch::ones({0});
  return probeVideo(true, dummy_input_video, videoPath);
}

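// Register the ops with the dispatcher. Once this extension is loaded, they
// should be callable from Python as torch.ops.video_reader.<op_name>, e.g.
// torch.ops.video_reader.probe_video_from_file(path).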
TORCH_LIBRARY_FRAGMENT(video_reader, m) {
  m.def("read_video_from_memory", read_video_from_memory);
  m.def("read_video_from_file", read_video_from_file);
  m.def("probe_video_from_memory", probe_video_from_memory);
  m.def("probe_video_from_file", probe_video_from_file);
}

} // namespace video_reader
} // namespace vision