/*! \file
    \brief Template for a single-stage threadblock-scoped Blocked-Ell MMA.
*/

#pragma once
#include "hytlass/hytlass.h"
#include "hytlass/array.h"
#include "hytlass/aligned_buffer.h"
#include "hytlass/numeric_conversion.h"
#include "hytlass/numeric_types.h"
#include "hytlass/matrix_shape.h"
#include "hytlass/gemm/gemm.h"
#include "hytlass/gemm/threadblock/mma_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace hytlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Single-stage threadblock-scoped Blocked-Ell matrix multiply-accumulate.
///
/// Derives from MmaBase<Shape_, Policy_, 1>, i.e. exactly one stage (enforced by a
/// static_assert below). One operand is fetched from global memory with the iterator's plain
/// load(); the other is fetched through load_with_ell_index() / load_with_ell_index_fast()
/// driven by an Ell iterator. Which operand takes which path is chosen at compile time by the
/// template flags of operator().
template <
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape_,
  /// Iterates over tiles of A operand in global memory
  /// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorA_,
  /// Iterates over tiles of A operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorA_,
  /// Iterates over tiles of B operand in global memory
  /// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorB_,
  /// Iterates over tiles of B operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorB_,
  /// Data type of accumulator matrix
  typename ElementC_,
  /// Layout of accumulator matrix
  typename LayoutC_,
  /// Policy describing tuning details (concept: MmaPolicy)
  typename Policy_,
  /// Transformation applied to A operand
  typename TransformA_ = NumericArrayConverter<
    typename SmemIteratorA_::Element,
    typename IteratorA_::Element,
    IteratorA_::Fragment::kElements>,
  /// Transformation applied to B operand
  typename TransformB_ = NumericArrayConverter<
    typename SmemIteratorB_::Element,
    typename IteratorB_::Element,
    IteratorB_::Fragment::kElements>,
  /// Used for partial specialization
  typename Enable = bool
>
class EllMmaSingleStage : public MmaBase<Shape_, Policy_, 1> {
public:

  ///< Base class
  using Base = MmaBase<Shape_, Policy_, 1>;

  using Shape = Shape_;             ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA = IteratorA_;     ///< Iterates over tiles of A operand in global memory
  using IteratorB = IteratorB_;     ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_;       ///< Data type of accumulator matrix
  using LayoutC = LayoutC_;         ///< Layout of accumulator matrix
  using Policy = Policy_;           ///< Policy describing tuning details

  using SmemIteratorA = SmemIteratorA_;   ///< Writes tiles of A operand to shared memory
  using SmemIteratorB = SmemIteratorB_;   ///< Writes tiles of B operand to shared memory

  using TransformA = TransformA_;   ///< Transformation applied to the A fragment before the store
  using TransformB = TransformB_;   ///< Transformation applied to the B fragment before the store

  //
  // Dependent types
  //

  /// Fragment of operand A loaded from global memory
  using FragmentA = typename IteratorA::Fragment;

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy::Operator::ArchTag;

  /// Iterator handed to the load_with_ell_index*() calls of the global iterators.
  /// Public because it appears in the signature of the public operator().
  /// NOTE(review): none of the headers included directly by this file is named for this
  /// iterator; presumably its definition arrives transitively - confirm.
  using EllIterator = typename hytlass::transform::threadblock::ell::Iterator;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = ComplexTransform::kNone;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  // Statically assert that kStages for EllSingleStage is one
  static_assert((Base::kStages == 1), "EllSingleStage requires kStages set to value 1");

private:

  using WarpFragmentA = typename Operator::FragmentA;
  using WarpFragmentB = typename Operator::FragmentB;

protected:

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

private:

  /// Clears both threadblock fragments, loads one tile of each operand from global memory
  /// and then advances iterator_A, iterator_B and ell_iterator by one step.
  ///
  /// The operand selected by is_A_sparse is read with load(); the other operand is read with
  /// load_with_ell_index_fast() when is_offset_constant is true and load_with_ell_index()
  /// otherwise. The branches are `if constexpr`, so only the calls that are actually used
  /// for a given flag combination are instantiated.
  template <bool is_A_sparse, bool is_offset_constant>
  HYTLASS_DEVICE
  void load_global_fragments_(
    FragmentA &tb_frag_A,                             ///< destination fragment for A
    FragmentB &tb_frag_B,                             ///< destination fragment for B
    IteratorA &iterator_A,                            ///< iterator over A operand in global memory
    IteratorB &iterator_B,                            ///< iterator over B operand in global memory
    EllIterator &ell_iterator) {                      ///< Ell iterator passed to the indexed load

    tb_frag_A.clear();
    tb_frag_B.clear();

    if constexpr (is_A_sparse) {
      // A takes the plain load, B takes the Ell-indexed load
      iterator_A.load(tb_frag_A);
      if constexpr (is_offset_constant) {
        iterator_B.load_with_ell_index_fast(tb_frag_B, ell_iterator);
      } else {
        iterator_B.load_with_ell_index(tb_frag_B, ell_iterator);
      }
    } else {
      // B takes the plain load, A takes the Ell-indexed load
      iterator_B.load(tb_frag_B);
      if constexpr (is_offset_constant) {
        iterator_A.load_with_ell_index_fast(tb_frag_A, ell_iterator);
      } else {
        iterator_A.load_with_ell_index(tb_frag_A, ell_iterator);
      }
    }

    ++iterator_A;
    ++iterator_B;
    ++ell_iterator;
  }

public:
  /// Construct from tensor references
  HYTLASS_DEVICE
  EllMmaSingleStage(
    typename Base::SharedStorage &shared_storage,       ///< Shared storage needed for internal use by threadblock-scoped GEMM
    int thread_idx,                                     ///< ID within the threadblock
    int warp_idx,                                       ///< ID of warp
    int lane_idx                                        ///< ID of each thread within a warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
    smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {

    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});

  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  ///
  /// \tparam is_A_sparse         true: A is read with load() and B with the Ell-indexed load;
  ///                             false: the roles of A and B are swapped
  /// \tparam is_offset_constant  true: use load_with_ell_index_fast(); false: load_with_ell_index()
  ///
  /// Contains __syncthreads() barriers, so every thread of the threadblock has to call it with
  /// the same gemm_k_iterations.
  template<bool is_A_sparse, bool is_offset_constant>
  HYTLASS_DEVICE
  void operator()(
    int gemm_k_iterations,                            ///< number of iterations of the mainloop
    FragmentC &accum,                                 ///< destination accumulator tile
    IteratorA iterator_A,                             ///< iterator over A operand in global memory
    IteratorB iterator_B,                             ///< iterator over B operand in global memory
    FragmentC const &src_accum,                       ///< source accumulator tile
    EllIterator &ell_iterator,                        ///< Ell iterator; advanced once per global tile load
    TransformA transform_A = TransformA(),            ///< transformation applied to A fragment
    TransformB transform_B = TransformB()) {          ///< transformation applied to B fragment

    //
    // Prologue
    //

    // Perform accumulation in the 'd' output operand
    accum = src_accum;
    Operator warp_mma;

    // Index of the warp fragment slot written by the next shared-memory load.
    // Stays 0 when there is a single warp-level iteration (only one slot exists then).
    int warp_stage_idx = 0;

    FragmentA tb_frag_A;
    FragmentB tb_frag_B;

    // Fetch the first threadblock tile of both operands
    this->template load_global_fragments_<is_A_sparse, is_offset_constant>(
      tb_frag_A, tb_frag_B, iterator_A, iterator_B, ell_iterator);

    // Avoid reading out of bounds
    iterator_A.clear_mask(gemm_k_iterations <= 1);
    iterator_B.clear_mask(gemm_k_iterations <= 1);

    // Only the operand read with the plain load() receives the Ell block-size mask
    if constexpr (is_A_sparse) {
      iterator_A.ell_add_mask(ell_iterator.get_blocksize());
    } else {
      iterator_B.ell_add_mask(ell_iterator.get_blocksize());
    }

    // Write fragments to lds
    this->smem_iterator_A_.store(transform_A(tb_frag_A));
    this->smem_iterator_B_.store(transform_B(tb_frag_B));
    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math instructions
    // (a single slot suffices when there is only one warp-level iteration)
    WarpFragmentA warp_frag_A[Base::kWarpGemmIterations == 1 ? 1 : 2];
    WarpFragmentB warp_frag_B[Base::kWarpGemmIterations == 1 ? 1 : 2];

    // Read lds
    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);
    this->warp_tile_iterator_A_.load(warp_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_frag_B[0]);
    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    if constexpr (Base::kWarpGemmIterations != 1) {
      warp_stage_idx ^= 1;
    }

    // Mainloop
    HYTLASS_GEMM_LOOP
    for (; gemm_k_iterations > 0; --gemm_k_iterations) {

      //
      // Loop over GEMM K dimension
      //
      HYTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {

        if (warp_mma_k == Base::kWarpGemmIterations - 1) {

          // Barrier before the store: the shared-memory iterators are never advanced, so the
          // store below writes the same location that was written in the previous store.
          __syncthreads();

          // Write fragments to lds
          this->smem_iterator_A_.store(transform_A(tb_frag_A));
          this->smem_iterator_B_.store(transform_B(tb_frag_B));

          // Rewind the warp-level iterators along K
          this->warp_tile_iterator_A_.add_tile_offset(
                {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
          this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});

          // Barrier after the store: the warp-level loads below read what was just written
          __syncthreads();

        }

        // k-group of the fragment loaded next. With a single warp-level iteration the
        // modulo yields 0, which equals warp_mma_k % kWarpGemmIterations in that case.
        int const next_kgroup = (warp_mma_k + 1) % Base::kWarpGemmIterations;
        this->warp_tile_iterator_A_.set_kgroup_index(next_kgroup);
        this->warp_tile_iterator_B_.set_kgroup_index(next_kgroup);

        this->warp_tile_iterator_A_.load(warp_frag_A[warp_stage_idx]);
        this->warp_tile_iterator_B_.load(warp_frag_B[warp_stage_idx]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        // After the toggle the index names the slot that was filled one step earlier,
        // which is the one consumed by warp_mma below.
        if constexpr (Base::kWarpGemmIterations != 1) {
          warp_stage_idx ^= 1;
        }

        if (warp_mma_k == 0) {

          // Fetch the next threadblock tile of both operands from global memory
          this->template load_global_fragments_<is_A_sparse, is_offset_constant>(
            tb_frag_A, tb_frag_B, iterator_A, iterator_B, ell_iterator);

          // Avoid reading out of bounds if this was the last loop iteration
          iterator_A.clear_mask(gemm_k_iterations <= 2);
          iterator_B.clear_mask(gemm_k_iterations <= 2);

        }

        warp_mma(accum, warp_frag_A[warp_stage_idx],
                 warp_frag_B[warp_stage_idx], accum);

      }
    }
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace hytlass

/////////////////////////////////////////////////////////////////////////////////////////////////
