"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "881a6b58c3b5594d7f2ca1150b5a6779dceee808"
gpu_decoder.cpp 2.02 KB
Newer Older
Prabhat Roy's avatar
Prabhat Roy committed
1
2
3
4
5
#include "gpu_decoder.h"
#include <c10/cuda/CUDAGuard.h>

/* Set cuda device, create cuda context and initialise the demuxer and decoder.
 */
6
7
8
9
GPUDecoder::GPUDecoder(std::string src_file, torch::Device dev)
    : demuxer(src_file.c_str()) {
  at::cuda::CUDAGuard device_guard(dev);
  device = device_guard.current_device().index();
Prabhat Roy's avatar
Prabhat Roy committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
  check_for_cuda_errors(
      cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__);
  decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec()));
  initialised = true;
}

/* Tear down the decoder and drop the primary-context reference taken by the
 * constructor. The device guard is scoped to the whole teardown.
 * NOTE(review): if check_for_cuda_errors can throw, it would do so from a
 * destructor here -- confirm its failure behaviour.
 */
GPUDecoder::~GPUDecoder() {
  at::cuda::CUDAGuard device_guard(device);
  decoder.release();
  if (!initialised) {
    // Context was never retained successfully; nothing further to undo.
    return;
  }
  check_for_cuda_errors(
      cuDevicePrimaryCtxRelease(device), __LINE__, __FILE__);
}

/* Return the next decoded frame.
 *
 * Repeatedly demuxes a packet, hands it to the decoder and asks the decoder
 * for a frame, until either a non-empty tensor comes back or the demuxer
 * reports zero bytes. Whatever the last fetch produced is returned, which may
 * be an empty tensor once the demuxer has nothing more to deliver.
 */
torch::Tensor GPUDecoder::decode() {
  unsigned long videoBytes = 0;
  uint8_t* video = nullptr;
  at::cuda::CUDAGuard device_guard(device);
  torch::Tensor frame;
  do {
    // demux() fills in the packet pointer and its size; a size of zero ends
    // the loop after this iteration's decode/fetch.
    demuxer.demux(&video, &videoBytes);
    decoder.decode(video, videoBytes);
    frame = decoder.fetch_frame();
  } while (frame.numel() == 0 && videoBytes > 0);
  return frame;
}

Prabhat Roy's avatar
Prabhat Roy committed
41
42
43
44
45
46
47
48
/* Reposition the demuxer at `timestamp`.
 *
 * When `keyframes_only` is true no seek flag is passed to the demuxer;
 * otherwise AVSEEK_FLAG_ANY is passed.
 */
void GPUDecoder::seek(double timestamp, bool keyframes_only) {
  int seek_flags = AVSEEK_FLAG_ANY;
  if (keyframes_only) {
    seek_flags = 0;
  }
  demuxer.seek(timestamp, seek_flags);
}

49
50
51
52
53
54
55
56
57
58
/* Build the metadata dictionary: a single "video" entry that maps to a
 * dictionary holding the "duration" and "fps" values reported by the demuxer.
 */
c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
    get_metadata() const {
  using StreamInfo = c10::Dict<std::string, double>;

  // Per-stream values first, then wrap them under the stream name.
  StreamInfo video_info;
  video_info.insert("duration", demuxer.get_duration());
  video_info.insert("fps", demuxer.get_fps());

  c10::Dict<std::string, StreamInfo> result;
  result.insert("video", video_info);
  return result;
}

Prabhat Roy's avatar
Prabhat Roy committed
59
60
/* Register GPUDecoder as a custom class in the `torchvision` library
 * namespace. Exposed methods: the (std::string, torch::Device) constructor,
 * "seek", "get_metadata", and "next" (bound to GPUDecoder::decode).
 */
TORCH_LIBRARY(torchvision, m) {
  auto decoder_class = m.class_<GPUDecoder>("GPUDecoder");
  decoder_class.def(torch::init<std::string, torch::Device>());
  decoder_class.def("seek", &GPUDecoder::seek);
  decoder_class.def("get_metadata", &GPUDecoder::get_metadata);
  decoder_class.def("next", &GPUDecoder::decode);
}