#include "video.h"

#include <algorithm>
#include <regex>

namespace vision {
namespace video {

namespace {

// Timeout handed to the decoder (params.timeoutMs and decoder.decode());
// by its name it is in milliseconds, i.e. 10 minutes.
constexpr size_t decoderTimeoutMs = 600000;
// Pixel format requested for decoded video frames (packed 8-bit RGB).
constexpr AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
// Sample format constant for audio (32-bit float).
constexpr AVSampleFormat defaultAudioSampleFormat = AV_SAMPLE_FMT_FLT;

// Copies the raw payload of a decoded message into the storage of `frame`,
// interpreting that storage as elements of type T.
//
// At most `frame.numel() * sizeof(T)` bytes are copied, so a payload larger
// than the tensor can never overrun its buffer. The tensor is assumed to be
// contiguous and already sized by the caller (NOTE(review): not checked here).
//
// Returns the number of bytes actually written (0 for an empty tensor).
template <typename T>
size_t fillTensorList(DecoderOutputMessage& msgs, torch::Tensor& frame) {
  const auto& msg = msgs;
  if (frame.numel() <= 0) {
    return 0;
  }
  T* frameData = frame.data_ptr<T>();
  const size_t capacityInBytes =
      static_cast<size_t>(frame.numel()) * sizeof(T);
  const size_t payloadInBytes = msg.payload->length();
  // Clamp to the destination size: copying the full payload unconditionally
  // was a buffer overflow whenever the payload exceeded the tensor.
  const size_t sizeInBytes = std::min(payloadInBytes, capacityInBytes);
  memcpy(frameData, msg.payload->data(), sizeInBytes);
  return sizeInBytes;
}

// Copies a decoded video message into `videoFrame`, treating the tensor's
// storage as raw uint8 bytes. Returns whatever fillTensorList returns.
size_t fillVideoTensor(DecoderOutputMessage& msgs, torch::Tensor& videoFrame) {
  using Pixel = uint8_t;
  const size_t result = fillTensorList<Pixel>(msgs, videoFrame);
  return result;
}

// Copies a decoded audio message into `audioFrame`, treating the tensor's
// storage as float samples. Returns whatever fillTensorList returns.
size_t fillAudioTensor(DecoderOutputMessage& msgs, torch::Tensor& audioFrame) {
  using Sample = float;
  const size_t result = fillTensorList<Sample>(msgs, audioFrame);
  return result;
}

// Looks up `stream_string` in the table of supported stream types.
//
// Returns an iterator to the matching {name, MediaType} entry of a static
// table (so the iterator stays valid for the program's lifetime).
// Throws (via TORCH_CHECK) when the name is not one of
// "video", "audio", "subtitle" or "cc".
std::array<std::pair<std::string, ffmpeg::MediaType>, 4>::const_iterator
_parse_type(const std::string& stream_string) {
  static const std::array<std::pair<std::string, MediaType>, 4> types = {{
      {"video", TYPE_VIDEO},
      {"audio", TYPE_AUDIO},
      {"subtitle", TYPE_SUBTITLE},
      {"cc", TYPE_CC},
  }};
  // Capture by reference: the lambda does not outlive this call, so there is
  // no reason to copy the string.
  const auto device = std::find_if(
      types.begin(),
      types.end(),
      [&stream_string](const std::pair<std::string, MediaType>& p) {
        return p.first == stream_string;
      });
  // Checking the condition directly (rather than TORCH_CHECK(false) after an
  // early return) guarantees every control path ends in a return statement.
  TORCH_CHECK(
      device != types.end(),
      "Expected one of [audio, video, subtitle, cc] ",
      stream_string);
  return device;
}

// Canonical name ("video", "audio", "subtitle", "cc") of a stream type.
// Throws through _parse_type when the name is unknown.
std::string parse_type_to_string(const std::string& stream_string) {
  return _parse_type(stream_string)->first;
}

// MediaType enumerator for a stream type name.
// Throws through _parse_type when the name is unknown.
MediaType parse_type_to_mt(const std::string& stream_string) {
  return _parse_type(stream_string)->second;
}

// Parses a stream specifier of the form "<type>" or "<type>:<index>",
// e.g. "video" or "audio:1".
//
// Returns (canonical type name, index); the index is -1 when omitted.
// Throws (via TORCH_CHECK) on an empty string, a string that does not match
// the pattern, an unknown type name, or an index c10::stoi cannot convert
// (e.g. one outside the int range).
std::tuple<std::string, long> _parseStream(const std::string& streamString) {
  TORCH_CHECK(!streamString.empty(), "Stream string must not be empty");
  static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
  std::smatch match;

  TORCH_CHECK(
      std::regex_match(streamString, match, regex),
      "Invalid stream string: '",
      streamString,
      "'");

  // Validates the type name against the supported set (throws otherwise).
  std::string type_ = parse_type_to_string(match[1].str());
  long index_ = -1;
  if (match[2].matched) {
    try {
      index_ = c10::stoi(match[2].str());
    } catch (const std::exception&) {
      TORCH_CHECK(
          false,
          "Could not parse stream index '",
          match[2].str(),
          "' in stream string '",
          streamString,
          "'");
    }
  }
  return std::make_tuple(type_, index_);
}

} // namespace

void Video::_getDecoderParams(
    double videoStartS,
    int64_t getPtsOnly,
    std::string stream,
    long stream_id = -1,
    bool all_streams = false,
    double seekFrameMarginUs = 10) {
  int64_t videoStartUs = int64_t(videoStartS * 1e6);

  params.timeoutMs = decoderTimeoutMs;
  params.startOffset = videoStartUs;
  params.seekAccuracy = seekFrameMarginUs;
  params.headerOnly = false;

  params.preventStaleness = false; // not sure what this is about

  if (all_streams == true) {
    MediaFormat format;
    format.stream = -2;
    format.type = TYPE_AUDIO;
    params.formats.insert(format);

    format.type = TYPE_VIDEO;
    format.stream = -2;
    format.format.video.width = 0;
    format.format.video.height = 0;
    format.format.video.cropImage = 0;
    format.format.video.format = defaultVideoPixelFormat;
    params.formats.insert(format);

    format.type = TYPE_SUBTITLE;
    format.stream = -2;
    params.formats.insert(format);

    format.type = TYPE_CC;
    format.stream = -2;
    params.formats.insert(format);
  } else {
    // parse stream type
    MediaType stream_type = parse_type_to_mt(stream);

    // TODO: reset params.formats
    std::set<MediaFormat> formats;
    params.formats = formats;
    // Define new format
    MediaFormat format;
    format.type = stream_type;
    format.stream = stream_id;
    if (stream_type == TYPE_VIDEO) {
      format.format.video.width = 0;
      format.format.video.height = 0;
      format.format.video.cropImage = 0;
      format.format.video.format = defaultVideoPixelFormat;
    }
    params.formats.insert(format);
  }

} // _get decoder params

// Opens `videoPath` and selects the stream described by `stream`
// (e.g. "video" or "audio:0").
//
// The decoder is first initialised with all streams so that per-stream
// headers can be summarised into streamsMetadata:
//   "video" -> {"duration", "fps"}, "audio" -> {"duration", "framerate"}.
// Only video and audio headers are reported. The decoder is then
// re-initialised for the requested stream via setCurrentStream(), whose
// result is what ends up in `succeeded`.
Video::Video(std::string videoPath, std::string stream) {
  // parse stream information
  current_stream = _parseStream(stream);
  // note that in the initial call we want to get all streams
  _getDecoderParams(
      0, // video start
      0, // headerOnly
      std::get<0>(current_stream), // stream info (ignored for all streams)
      long(-1), // stream_id
      true // read all streams
  );

  // TODO: add read from memory option
  params.uri = std::move(videoPath);

  std::vector<double> audioFPS, videoFPS;
  std::vector<double> audioDuration, videoDuration;

  // callback and metadata are members.
  // NOTE(review): `callback` is moved here and again in setCurrentStream();
  // confirm this is intended (e.g. that it is always empty).
  succeeded = decoder.init(params, std::move(callback), &metadata);
  if (succeeded) {
    for (const auto& header : metadata) {
      const double fps = double(header.fps);
      // 1e-6 factor: presumably microseconds -> seconds.
      const double duration = double(header.duration) * 1e-6;

      if (header.format.type == TYPE_VIDEO) {
        videoFPS.push_back(fps);
        videoDuration.push_back(duration);
      } else if (header.format.type == TYPE_AUDIO) {
        audioFPS.push_back(fps);
        audioDuration.push_back(duration);
      }
    }
  }

  c10::Dict<std::string, std::vector<double>> audioMetadata;
  audioMetadata.insert("duration", audioDuration);
  audioMetadata.insert("framerate", audioFPS);
  c10::Dict<std::string, std::vector<double>> videoMetadata;
  videoMetadata.insert("duration", videoDuration);
  videoMetadata.insert("fps", videoFPS);
  streamsMetadata.insert("video", videoMetadata);
  streamsMetadata.insert("audio", audioMetadata);

  succeeded = Video::setCurrentStream(stream);
  LOG(INFO) << "\nDecoder inited with: " << succeeded << "\n";
  if (std::get<1>(current_stream) != -1) {
    LOG(INFO)
        << "Stream index set to " << std::get<1>(current_stream)
        << ". If you encounter trouble, consider switching it to automatic stream discovery. \n";
  }
} // video

// Selects the stream to decode and re-initialises the decoder for it.
//
// `stream` uses the "<type>[:<index>]" syntax accepted by _parseStream; an
// empty string keeps the current selection. Decoding starts at seekTS when it
// is positive, otherwise at 0 (NOTE(review): seekTS is not written anywhere in
// this file — confirm where it is set). Returns the result of decoder.init().
bool Video::setCurrentStream(std::string stream = "video") {
  if (!stream.empty()) {
    // Parse once and assign: assigning an identical value is harmless, so
    // there is no need to run the regex a second time just to compare.
    current_stream = _parseStream(stream);
  }

  double ts = 0;
  if (seekTS > 0) {
    ts = seekTS;
  }

  _getDecoderParams(
      ts, // video start
      0, // headerOnly
      std::get<0>(current_stream), // stream
      long(std::get<1>(current_stream)), // stream_id
      false // read all streams
  );

  // callback and metadata are members (declared in video.h)
  return decoder.init(params, std::move(callback), &metadata);
}

// Returns the (type name, index) pair of the currently selected stream.
std::tuple<std::string, int64_t> Video::getCurrentStream() const {
  const auto& selected = current_stream;
  return selected;
}

// Returns the per-stream metadata dictionary filled in by the constructor.
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video::
    getStreamMetadata() const {
  const auto& collected = streamsMetadata;
  return collected;
}

// Re-initialises the decoder for the currently selected stream so that
// decoding starts at `ts` seconds; the init result is stored in `succeeded`
// and logged.
void Video::Seek(double ts) {
  const std::string& streamType = std::get<0>(current_stream);
  const long streamIndex = long(std::get<1>(current_stream));

  // Arguments: start time, headerOnly flag, stream type, stream index,
  // all-streams flag.
  _getDecoderParams(ts, 0, streamType, streamIndex, false);

  // callback and metadata are members (declared in video.h)
  succeeded = decoder.init(params, std::move(callback), &metadata);
  LOG(INFO) << "Decoder init at seek " << succeeded << "\n";
}

std::tuple<torch::Tensor, double> Video::Next() {
  // if failing to decode simply return a null tensor (note, should we
  // raise an exeption?)
  double frame_pts_s;
  torch::Tensor outFrame = torch::zeros({0}, torch::kByte);

  // decode single frame
  DecoderOutputMessage out;
  int64_t res = decoder.decode(&out, decoderTimeoutMs);
  // if successfull
  if (res == 0) {
    frame_pts_s = double(double(out.header.pts) * 1e-6);

    auto header = out.header;
    const auto& format = header.format;

    // initialize the output variables based on type

    if (format.type == TYPE_VIDEO) {
      // note: this can potentially be optimized
      // by having the global tensor that we fill at decode time
      // (would avoid allocations)
      int outHeight = format.format.video.height;
      int outWidth = format.format.video.width;
      int numChannels = 3;
      outFrame = torch::zeros({outHeight, outWidth, numChannels}, torch::kByte);
      auto numberWrittenBytes = fillVideoTensor(out, outFrame);
      outFrame = outFrame.permute({2, 0, 1});

    } else if (format.type == TYPE_AUDIO) {
      int outAudioChannels = format.format.audio.channels;
      int bytesPerSample = av_get_bytes_per_sample(
          static_cast<AVSampleFormat>(format.format.audio.format));
      int frameSizeTotal = out.payload->length();

      CHECK_EQ(frameSizeTotal % (outAudioChannels * bytesPerSample), 0);
      int numAudioSamples =
          frameSizeTotal / (outAudioChannels * bytesPerSample);

      outFrame =
          torch::zeros({numAudioSamples, outAudioChannels}, torch::kFloat);

      auto numberWrittenBytes = fillAudioTensor(out, outFrame);
    }
    // currently not supporting other formats (will do soon)

    out.payload.reset();
  } else if (res == ENODATA) {
    LOG(INFO) << "Decoder ran out of frames (ENODATA)\n";
  } else {
    LOG(ERROR) << "Decoder failed with ERROR_CODE " << res;
  }

  return std::make_tuple(outFrame, frame_pts_s);
}

// Registers Video with torch::class_ under namespace "torchvision" and class
// name "Video". It exposes a constructor taking (video path, stream string)
// and binds each member function below to the name given as its string.
static auto registerVideo =
    torch::class_<Video>("torchvision", "Video")
        .def(torch::init<std::string, std::string>())
        .def("get_current_stream", &Video::getCurrentStream)
        .def("set_current_stream", &Video::setCurrentStream)
        .def("get_metadata", &Video::getStreamMetadata)
        .def("seek", &Video::Seek)
        .def("next", &Video::Next);

} // namespace video
} // namespace vision
