#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "dummy_dynamic_transform.hpp"

template <class T,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          class ConvStrides,
          class ConvDilations,
          class InLeftPads,
          class InRightPads>
void device_dummy_dynamic_transform(InDesc,
                                    const Tensor<T>& in_nchw,
                                    WeiDesc,
                                    const Tensor<T>& wei_kcyx,
                                    OutDesc,
                                    Tensor<T>& out_nkhw,
                                    ConvStrides,
                                    ConvDilations,
                                    InLeftPads,
                                    InRightPads,
                                    ck::index_t nrepeat)
{
    using namespace ck;

    using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;

    const auto in_nchw_desc  = make_dynamic_native_tensor_descriptor(to_array(InDesc::GetLengths()),
                                                                    to_array(InDesc::GetStrides()));
    const auto wei_kcyx_desc = make_dynamic_native_tensor_descriptor(
        to_array(WeiDesc::GetLengths()), to_array(WeiDesc::GetStrides()));
    const auto out_nkhw_desc = make_dynamic_native_tensor_descriptor(
        to_array(OutDesc::GetLengths()), to_array(OutDesc::GetStrides()));

    const auto conv_strides   = to_array(ConvStrides{});
    const auto conv_dilations = to_array(ConvDilations{});
    const auto in_left_pads   = to_array(InLeftPads{});
    const auto in_right_pads  = to_array(InRightPads{});

    {
        const auto tensor_descs = map_convolution_into_gemm(wei_kcyx_desc,
                                                            in_nchw_desc,
                                                            out_nkhw_desc,
                                                            conv_strides,
                                                            conv_dilations,
                                                            in_left_pads,
                                                            in_right_pads);

        const auto in_gemmk_gemmn_global_desc = tensor_descs.At(Number<0>{});
        print_array("cpu: in_gemmk_gemmn_global_desc:", in_gemmk_gemmn_global_desc.GetLengths());

        const auto idx0 = MultiIndex<2>({2591, 36991});
        const auto idx1 = in_gemmk_gemmn_global_desc.CalculateLowerIndex(idx0);
        const auto idx2 =
            in_gemmk_gemmn_global_desc.GetLowerTensorDescriptor().CalculateLowerIndex(idx1);

        const index_t offset = in_gemmk_gemmn_global_desc.CalculateOffset(idx0);

        print_array("idx0:", idx0);
        print_array("idx1:", idx1);
        print_array("idx2:", idx2);
        printf("offset %d\n", offset);
    }

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

    constexpr index_t BlockSize = 256;
    constexpr index_t GridSize  = 1;

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

#if 0
    using dummy_transform = DummyDynamicTransform<BlockSize>;

    for(index_t i = 0; i < 5; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            launch_kernel(run_gridwise_operation<dummy_transform,
                                                 index_t* const,
                                                 index_t* const,
                                                 float* const,
                                                 const DynamicNativeTensorDescriptor<4>,
                                                 const DynamicNativeTensorDescriptor<4>,
                                                 const DynamicNativeTensorDescriptor<4>,
                                                 const Array<index_t, 2>,
                                                 const Array<index_t, 2>,
                                                 const Array<index_t, 2>,
                                                 const Array<index_t, 2>>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          0,
                          0,
                          static_cast<index_t*>(in_nchw_device_buf.GetDeviceBuffer()),
                          static_cast<index_t*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                          static_cast<float*>(out_nkhw_device_buf.GetDeviceBuffer()),
                          wei_kcyx_desc,
                          in_nchw_desc,
                          out_nkhw_desc,
                          conv_strides,
                          conv_dilations,
                          in_left_pads,
                          in_right_pads);
        }
    }
#endif

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
