// conversion.h
#pragma once
2
#include <libtorio/ffmpeg/ffmpeg.h>
3
4
#include <torch/types.h>

moto-meta's avatar
moto-meta committed
5
namespace torio::io {
6
7
8
9
10
11
12
13
14

////////////////////////////////////////////////////////////////////////////////
// Audio
////////////////////////////////////////////////////////////////////////////////
template <c10::ScalarType dtype, bool is_planar>
class AudioConverter {
  const int num_channels;

 public:
15
  explicit AudioConverter(int num_channels);
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

  // Converts AVFrame* into Tensor of [T, C]
  torch::Tensor convert(const AVFrame* src);

  // Converts AVFrame* into pre-allocated Tensor.
  // The shape must be [C, T] if is_planar otherwise [T, C]
  void convert(const AVFrame* src, torch::Tensor& dst);
};

////////////////////////////////////////////////////////////////////////////////
// Image
////////////////////////////////////////////////////////////////////////////////
// Common state shared by the CPU image converters: the frame geometry,
// fixed for the lifetime of the converter (all members are const).
struct ImageConverterBase {
  // h, w, c: presumably stored into `height`, `width` and `num_channels`
  // respectively (the definition lives outside this header - confirm).
  ImageConverterBase(int h, int w, int c);

  const int height;
  const int width;
  const int num_channels;
};

////////////////////////////////////////////////////////////////////////////////
// Interlaced Images - NHWC
////////////////////////////////////////////////////////////////////////////////
// Image converter of the "interlaced" family.
struct InterlacedImageConverter : ImageConverterBase {
  // Inherit the constructors of ImageConverterBase unchanged.
  using ImageConverterBase::ImageConverterBase;

  // Writes the content of `src` into `dst`, a pre-allocated Tensor of
  // NHWC format.
  void convert(const AVFrame* src, torch::Tensor& dst);

  // Returns the content of `src` as a Tensor of NCHW format.
  torch::Tensor convert(const AVFrame* src);
};

// Image converter of the "interlaced" family; judging by the name, the
// variant for 16-bit samples (confirm against the implementation).
struct Interlaced16BitImageConverter : ImageConverterBase {
  // Inherit the constructors of ImageConverterBase unchanged.
  using ImageConverterBase::ImageConverterBase;

  // Writes the content of `src` into `dst`, a pre-allocated Tensor of
  // NHWC format.
  void convert(const AVFrame* src, torch::Tensor& dst);

  // Returns the content of `src` as a Tensor of NCHW format.
  torch::Tensor convert(const AVFrame* src);
};

////////////////////////////////////////////////////////////////////////////////
// Planar Images - NCHW
////////////////////////////////////////////////////////////////////////////////
// Image converter of the "planar" family.
// NOTE(review): the tensor layout is presumably NCHW for both overloads -
// confirm against the implementation.
struct PlanarImageConverter : ImageConverterBase {
  // Inherit the constructors of ImageConverterBase unchanged.
  using ImageConverterBase::ImageConverterBase;

  // Returns the content of `src` as a Tensor.
  torch::Tensor convert(const AVFrame* src);

  // Writes the content of `src` into the caller-provided Tensor `dst`.
  void convert(const AVFrame* src, torch::Tensor& dst);
};

////////////////////////////////////////////////////////////////////////////////
// Family of YUVs - NCHW
////////////////////////////////////////////////////////////////////////////////
// Image converter for YUV420P frames (format inferred from the class name).
// Only the frame size is supplied by the caller; how the channel count of
// ImageConverterBase is chosen is decided in the constructor definition,
// which is not part of this header.
class YUV420PConverter : public ImageConverterBase {
 public:
  YUV420PConverter(int height, int width);

  // Returns the content of `src` as a Tensor.
  torch::Tensor convert(const AVFrame* src);

  // Writes the content of `src` into the caller-provided Tensor `dst`.
  void convert(const AVFrame* src, torch::Tensor& dst);
};

74
75
76
77
78
79
80
// Image converter for YUV420P10LE frames (format inferred from the class
// name). Only the frame size is supplied by the caller; how the channel
// count of ImageConverterBase is chosen is decided in the constructor
// definition, which is not part of this header.
class YUV420P10LEConverter : public ImageConverterBase {
 public:
  YUV420P10LEConverter(int height, int width);

  // Returns the content of `src` as a Tensor.
  torch::Tensor convert(const AVFrame* src);

  // Writes the content of `src` into the caller-provided Tensor `dst`.
  void convert(const AVFrame* src, torch::Tensor& dst);
};

81
82
83
84
85
86
87
88
89
// Image converter for NV12 frames (format inferred from the class name).
// Only the frame size is supplied by the caller; how the channel count of
// ImageConverterBase is chosen is decided in the constructor definition,
// which is not part of this header.
class NV12Converter : public ImageConverterBase {
 public:
  NV12Converter(int height, int width);

  // Returns the content of `src` as a Tensor.
  torch::Tensor convert(const AVFrame* src);

  // Writes the content of `src` into the caller-provided Tensor `dst`.
  void convert(const AVFrame* src, torch::Tensor& dst);
};

#ifdef USE_CUDA

90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// Note:
// GPU decoders are tricky. They allow changing the resolution as part of the
// decoder options, and the resulting resolution is (seemingly) not
// retrievable. Therefore, we adopt delayed frame size initialization, and for
// that purpose we do not inherit from ImageConverterBase.
struct CudaImageConverterBase {
  explicit CudaImageConverterBase(const torch::Device& device);

  const torch::Device device;
  // Starts out false; presumably set once the frame size below has been
  // determined (confirm against the implementation).
  bool init{false};
  // -1 until the delayed initialization described above has taken place.
  int height{-1};
  int width{-1};
};

class NV12CudaConverter : CudaImageConverterBase {
  torch::Tensor tmp_uv{};
105
106

 public:
107
  explicit NV12CudaConverter(const torch::Device& device);
108
109
110
111
  void convert(const AVFrame* src, torch::Tensor& dst);
  torch::Tensor convert(const AVFrame* src);
};

112
113
class P010CudaConverter : CudaImageConverterBase {
  torch::Tensor tmp_uv{};
114
115

 public:
116
  explicit P010CudaConverter(const torch::Device& device);
117
118
119
120
  void convert(const AVFrame* src, torch::Tensor& dst);
  torch::Tensor convert(const AVFrame* src);
};

121
class YUV444PCudaConverter : CudaImageConverterBase {
moto's avatar
moto committed
122
 public:
123
  explicit YUV444PCudaConverter(const torch::Device& device);
moto's avatar
moto committed
124
125
126
127
  void convert(const AVFrame* src, torch::Tensor& dst);
  torch::Tensor convert(const AVFrame* src);
};

128
#endif
moto-meta's avatar
moto-meta committed
129
} // namespace torio::io