// Copyright (c) 2018-2022 NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cc/symmetries.h"

namespace minigo {
namespace symmetry {

namespace {
// Composition table backing Concat(): kConcatTable[a][b] is the single
// symmetry with the same effect as applying `a` first and then `b`. This
// agrees with the coordinate mappings in ApplySymmetry(Symmetry, Coord)
// below, e.g. kRot90 followed by kFlip equals kFlipRot270.
// Rows and columns are indexed by Symmetry value, so this presumably relies
// on the enumerators being numbered 0..kNumSymmetries-1 in the order listed
// in kAllSymmetries (confirm against the enum declaration).
// Generated by SymmetryTest.ConcatTable.
// clang-format off
constexpr Symmetry kConcatTable[kNumSymmetries][kNumSymmetries] = {
  {kIdentity, kRot90, kRot180, kRot270, kFlip, kFlipRot90, kFlipRot180, kFlipRot270},
  {kRot90, kRot180, kRot270, kIdentity, kFlipRot270, kFlip, kFlipRot90, kFlipRot180},
  {kRot180, kRot270, kIdentity, kRot90, kFlipRot180, kFlipRot270, kFlip, kFlipRot90},
  {kRot270, kIdentity, kRot90, kRot180, kFlipRot90, kFlipRot180, kFlipRot270, kFlip},
  {kFlip, kFlipRot90, kFlipRot180, kFlipRot270, kIdentity, kRot90, kRot180, kRot270},
  {kFlipRot90, kFlipRot180, kFlipRot270, kFlip, kRot270, kIdentity, kRot90, kRot180},
  {kFlipRot180, kFlipRot270, kFlip, kFlipRot90, kRot180, kRot270, kIdentity, kRot90},
  {kFlipRot270, kFlip, kFlipRot90, kFlipRot180, kRot90, kRot180, kRot270, kIdentity},
};
// clang-format on

// Printable name of each Symmetry, indexed by its enum value. operator<<
// uses it for in-range values only. The order must match the Symmetry
// enumerators (NOTE(review): assumed to be the same order as kAllSymmetries;
// confirm against the enum declaration).
constexpr const char* const kNames[kNumSymmetries] = {
    "kIdentity", "kRot90",     "kRot180",     "kRot270",
    "kFlip",     "kFlipRot90", "kFlipRot180", "kFlipRot270",
};
}  // namespace

std::ostream& operator<<(std::ostream& os, Symmetry sym) {
  // Only values inside [0, kNumSymmetries) have an entry in kNames. Anything
  // else is written as its raw integer in angle brackets, so corrupt values
  // stay visible in logs instead of indexing out of bounds.
  const bool has_name = sym >= 0 && sym < kNumSymmetries;
  if (has_name) {
    return os << kNames[sym];
  }
  return os << "<" << static_cast<int>(sym) << ">";
}

// Every symmetry, listed in the same order as kNames. The kCoords initializer
// below iterates over this array and uses each element as the index of the
// table row it fills.
const std::array<Symmetry, kNumSymmetries> kAllSymmetries = {
    kIdentity, kRot90,     kRot180,     kRot270,
    kFlip,     kFlipRot90, kFlipRot180, kFlipRot270,
};

// kInverseSymmetries[s] is the symmetry that undoes `s`, indexed by Symmetry
// value. kRot90 and kRot270 invert each other. kIdentity, kRot180 and every
// kFlip* variant are their own inverses; this is consistent with the
// kIdentity entries in kConcatTable.
const std::array<Symmetry, kNumSymmetries> kInverseSymmetries = {
    kIdentity, kRot270,    kRot180,     kRot90,
    kFlip,     kFlipRot90, kFlipRot180, kFlipRot270,
};

const std::array<std::array<Coord, kNumMoves>, kNumSymmetries> kCoords = []() {
  std::array<std::array<Coord, kNumMoves>, kNumSymmetries> result;
  std::array<Coord, kNumMoves> original;
  for (int i = 0; i < kNumMoves; ++i) {
    original[i] = i;
  }
  for (auto c : kAllSymmetries) {
    ApplySymmetry<kN, 1>(c, original.data(), result[c].data());
    result[c][Coord::kPass] = original[Coord::kPass];
  }
  return result;
}();

// Maps a single coordinate through `sym`.
// Coordinates at or above Coord::kPass are not board points and are returned
// unchanged. An out-of-range `sym` is logged as FATAL; Coord::kInvalid is
// returned only if that log call returns.
Coord ApplySymmetry(Symmetry sym, Coord c) {
  if (c >= Coord::kPass) {
    return c;
  }

  // Index of the last row/column on the board.
  const int edge = kN - 1;
  const int src_row = c / kN;
  const int src_col = c % kN;

  // NOTE(review): assumes Coord's two-argument constructor takes (row, col),
  // matching how src_row/src_col are derived above. Each kFlip* case is the
  // transpose (kFlip) followed by the corresponding rotation.
  switch (sym) {
    case kIdentity:
      return c;
    case kRot90:
      return Coord(edge - src_col, src_row);
    case kRot180:
      return Coord(edge - src_row, edge - src_col);
    case kRot270:
      return Coord(src_col, edge - src_row);
    case kFlip:
      return Coord(src_col, src_row);
    case kFlipRot90:
      return Coord(edge - src_row, src_col);
    case kFlipRot180:
      return Coord(edge - src_col, edge - src_row);
    case kFlipRot270:
      return Coord(src_row, edge - src_col);
    default:
      break;
  }

  MG_LOG(FATAL) << static_cast<int>(sym);
  return Coord::kInvalid;
}

// Returns the single symmetry equivalent to applying `a` first and then `b`.
// Both arguments are used as indices into kConcatTable without a range check.
Symmetry Concat(Symmetry a, Symmetry b) {
  const auto& after_a = kConcatTable[a];
  return after_a[b];
}

}  // namespace symmetry
}  // namespace minigo
