#pragma once
#include <cstdint>

#include "src/fastertransformer/ck_extensions/include/ck_extensions/ft_gemm_configs.h"
#include "src/fastertransformer/utils/activation_types.h"
// #include "hip/hip_runtime.h"

namespace fastertransformer {

/**
 * Runner for the mixture-of-experts (MoE) GEMM entry points declared below.
 *
 * This header only declares the interface; the member definitions are not part
 * of this file.
 *
 * @tparam T          Type used for activations, scales and compute.
 * @tparam WeightType Type of the MoE weights.
 *
 * NOTE(review): every method originally carried a `hipStream_t stream`
 * parameter that is currently commented out (together with the
 * "hip/hip_runtime.h" include), so no stream can be passed by the caller.
 */
template<typename T, /* The type used for activations/scales/compute */
         typename WeightType /* The type for the MoE weights */>
class MoeGemmRunner {
public:
    MoeGemmRunner();

    /**
     * MoE GEMM variant that additionally receives `biases` and an
     * `activation_type`.
     *
     * Parameter roles below are inferred from names and types only -- TODO
     * confirm against the implementation.
     *
     * @param A                        Input buffer of type T.
     * @param B                        Weight buffer of type WeightType.
     * @param weight_scales            Scales of type T (presumably used with
     *                                 quantized weights -- verify).
     * @param biases                   Bias buffer of type T.
     * @param C                        Output buffer of type T.
     * @param total_rows_before_expert Per-expert row bookkeeping (presumably a
     *                                 running row count per expert -- verify).
     * @param total_rows               Total number of rows in the problem.
     * @param gemm_n                   GEMM N dimension.
     * @param gemm_k                   GEMM K dimension.
     * @param num_experts              Number of experts.
     * @param activation_type          Activation selector.
     */
    void moe_gemm_bias_act(const T*          A,
                           const WeightType* B,
                           const T*          weight_scales,
                           const T*          biases,
                           T*                C,
                           int64_t*          total_rows_before_expert,
                           int64_t           total_rows,
                           int64_t           gemm_n,
                           int64_t           gemm_k,
                           int               num_experts,
                           ActivationType    activation_type);
    // Disabled trailing parameter: hipStream_t stream

    /**
     * MoE GEMM variant without bias or activation arguments.
     *
     * Takes the same arguments as moe_gemm_bias_act() minus `biases` and
     * `activation_type`.
     */
    void moe_gemm(const T*          A,
                  const WeightType* B,
                  const T*          weight_scales,
                  T*                C,
                  int64_t*          total_rows_before_expert,
                  int64_t           total_rows,
                  int64_t           gemm_n,
                  int64_t           gemm_k,
                  int               num_experts);
    // Disabled trailing parameter: hipStream_t stream

private:
    /**
     * Internal helper parameterized on the epilogue tag and an explicit
     * `gemm_config`.
     *
     * @param occupancy Optional pointer, defaults to nullptr (presumably an
     *                  out-parameter for an occupancy query -- TODO confirm).
     */
    template<typename EpilogueTag>
    void dispatch_to_arch(const T*          A,
                          const WeightType* B,
                          const T*          weight_scales,
                          const T*          biases,
                          T*                C,
                          int64_t*          total_rows_before_expert,
                          int64_t           total_rows,
                          int64_t           gemm_n,
                          int64_t           gemm_k,
                          int               num_experts,
                          CutlassGemmConfig gemm_config,
                          // Disabled parameter: hipStream_t stream,
                          int*              occupancy = nullptr);

    /**
     * Internal helper parameterized on the epilogue tag; takes the same
     * buffers and sizes as the public entry points, including `biases`.
     */
    template<typename EpilogueTag>
    void run_gemm(const T*          A,
                  const WeightType* B,
                  const T*          weight_scales,
                  const T*          biases,
                  T*                C,
                  int64_t*          total_rows_before_expert,
                  int64_t           total_rows,
                  int64_t           gemm_n,
                  int64_t           gemm_k,
                  int               num_experts);
    // Disabled trailing parameter: hipStream_t stream

    // Zero-initialized so the members never hold indeterminate values, even if
    // a constructor path does not assign them.
    // NOTE(review): presumably device architecture / compute-unit count --
    // confirm against the constructor.
    int sm_{0};
    int multi_processor_count_{0};
};

}  // namespace fastertransformer