#include "hip/hip_runtime.h"
// Copyright (C) 2017  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_HIPSOLVER_CU_
#define DLIB_DNN_HIPSOLVER_CU_

#ifdef DLIB_USE_ROCM

#include "hipsolver_dlibapi.h"
#include <hipsolver/hipsolver.h>
#include "rocm_utils.h"

// ----------------------------------------------------------------------------------------

// Maps a hipsolver status code to a human readable message.  Only the
// NOT_INITIALIZED and ALLOC_FAILED codes have dedicated messages; every other
// value (including success) yields the generic fallback text.
static const char* hipsolver_get_error_string(hipsolverStatus_t s)
{
    if (s == HIPSOLVER_STATUS_NOT_INITIALIZED)
        return "ROCM Runtime API initialization failed.";

    if (s == HIPSOLVER_STATUS_ALLOC_FAILED)
        return "ROCM Resources could not be allocated.";

    return "A call to hipsolver failed";
}

// Check the return value of a call to the hipsolver runtime for an error condition.
//
// `call` is evaluated exactly once.  If it does not return HIPSOLVER_STATUS_SUCCESS
// a dlib::hipsolver_error is thrown whose message contains the text of `call`
// (stringified via #call), the file and line of the macro invocation, the numeric
// status code, and the text from hipsolver_get_error_string().
//
// The body is wrapped in do{...}while(false) so the macro behaves like a single
// statement and can be used safely in an unbraced if/else.  Because the argument
// text ends up in the error message, renaming variables at a call site changes the
// message that is produced.
#define CHECK_HIPSOLVER(call)                                                      \
do{                                                                              \
    const hipsolverStatus_t error = call;                                         \
    if (error != HIPSOLVER_STATUS_SUCCESS)                                        \
    {                                                                          \
        std::ostringstream sout;                                               \
        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
        sout << "code: " << error << ", reason: " << hipsolver_get_error_string(error);\
        throw dlib::hipsolver_error(sout.str());                                \
    }                                                                          \
}while(false)

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

namespace dlib
{
    namespace rocm 
    {

    // -----------------------------------------------------------------------------------

        // Holds the hipsolver dense handles used by this file in a table that
        // get_handle() indexes by the device id reported by hipGetDevice().  A handle
        // is created the first time it is requested for a device and all created
        // handles are destroyed together with the context object.
        class hipsolver_context
        {
        public:
            // Start with 16 empty slots; get_handle() enlarges the table on demand.
            hipsolver_context() : handles(16)
            {
            }

            // not copyable
            hipsolver_context(const hipsolver_context&) = delete;
            hipsolver_context& operator=(const hipsolver_context&) = delete;

            // Destroys every handle that was actually created.  The status returned
            // by hipsolverDnDestroy() is not inspected.
            ~hipsolver_context()
            {
                for (const auto handle : handles)
                {
                    if (!handle)
                        continue;
                    hipsolverDnDestroy(handle);
                }
            }

            // Returns the handle associated with the device that is current for the
            // calling thread, creating it first if this context has none for that
            // device yet.  Errors from hipGetDevice()/hipsolverDnCreate() are reported
            // through the CHECK_ROCM / CHECK_HIPSOLVER macros.
            hipsolverDnHandle_t get_handle (
            )  
            { 
                int new_device_id;
                CHECK_ROCM(hipGetDevice(&new_device_id));

                // Grow the table (with some headroom) when the id is past its end.
                const long slot_count = static_cast<long>(handles.size());
                if (slot_count <= new_device_id)
                    handles.resize(new_device_id+16);

                // Reuse the existing handle for this device when there is one.
                if (handles[new_device_id])
                    return handles[new_device_id];

                // Otherwise create it now and hand it back.
                CHECK_HIPSOLVER(hipsolverDnCreate(&handles[new_device_id]));
                return handles[new_device_id];
            }

        private:

            // One slot per device id; an empty (null) slot means "not created yet".
            std::vector<hipsolverDnHandle_t> handles;
        };

        // Returns the hipsolver handle for the calling thread's current device.  The
        // backing hipsolver_context is thread_local, so each thread gets its own
        // context object.
        static hipsolverDnHandle_t context()
        {
            thread_local hipsolver_context per_thread_context;
            hipsolverDnHandle_t handle = per_thread_context.get_handle();
            return handle;
        }

    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------

        // Kernel: treats m as nr*nr contiguous floats and writes 1 to every element
        // whose flat index is a multiple of nr+1 (the diagonal entries of an nr x nr
        // matrix) and 0 to every other element.  The indices each thread handles come
        // from grid_stride_range(0, nr*nr).
        __global__ void _rocm_set_to_identity_matrix(float* m, size_t nr)
        {
            const size_t diagonal_step = nr+1;
            for (auto idx : grid_stride_range(0, nr*nr))
                m[idx] = (idx%diagonal_step == 0) ? 1 : 0;
        }

        // Overwrites m with an identity matrix by launching
        // _rocm_set_to_identity_matrix on its device memory.  Requires
        // m.size() == m.num_samples()*m.num_samples(), which is checked with
        // DLIB_CASSERT.
        void set_to_identity_matrix (
            tensor& m 
        )
        {
            DLIB_CASSERT(m.size() == m.num_samples()*m.num_samples());

            const max_jobs jobs(m.size());
            const auto device_data = m.device();
            const auto side_length = m.num_samples();
            launch_kernel(_rocm_set_to_identity_matrix, jobs, device_data, side_length);
        }

    // ------------------------------------------------------------------------------------

        // Before the object goes away, wait for any work flagged by did_work_lately
        // (see sync_if_needed()).  NOTE(review): presumably this is so the members
        // used as GPU buffers are not released while a solver call is still using
        // them -- confirm against the class declaration.
        inv::~inv()
        {
            sync_if_needed();
        }

    // ------------------------------------------------------------------------------------

        // Computes the inverse of the square matrix m_ and stores it in out.
        //
        // m_  - input; must satisfy m_.size() == m_.num_samples()*m_.num_samples().
        // out - resized to match the input and filled with the result.
        //
        // Failures reported by hipsolver as a status code are thrown as
        // dlib::hipsolver_error by CHECK_HIPSOLVER.  The solver's own `info` value is
        // written to device memory and can be read with get_last_status().
        void inv::
        operator() (
            const tensor& m_,
            resizable_tensor& out
        )
        {
            DLIB_CASSERT(m_.size() == m_.num_samples()*m_.num_samples(), "Input matrix must be square if you want to invert it.");
            // Work on a copy (m is not declared in this function; presumably a member
            // of inv) so the caller's tensor is left untouched by the solver calls.
            m = m_;

            // out starts as the identity: it is passed to getrs below as the
            // right-hand side, so solving m*X = I leaves X = inverse(m) in out.
            out.copy_size(m);
            set_to_identity_matrix(out);

            const int nc = m.num_samples();
            int Lwork;
            // Ask hipsolver how many floats of scratch space getrf needs.
            CHECK_HIPSOLVER(hipsolverDnSgetrf_bufferSize(context(), nc , nc, m.device(), nc, &Lwork));

            // Grow the scratch and pivot buffers only when they are too small.
            // sync_if_needed() runs first so that work from a previous call is
            // finished before the old buffer is replaced.
            if (Lwork > (int)workspace.size())
            {
                sync_if_needed();
                workspace = rocm_data_ptr<float>(Lwork);
            }
            if (nc > (int)Ipiv.size())
            {
                sync_if_needed();
                Ipiv = rocm_data_ptr<int>(nc);
            }
            // info holds a single int; (re)allocate it whenever its size is not 1.
            if (info.size() != 1)
            {
                info = rocm_data_ptr<int>(1);
            }

            // getrf factorizes m, then getrs uses that factorization and the pivots
            // to solve against out (see the hipSOLVER API reference for details).
            CHECK_HIPSOLVER(hipsolverDnSgetrf(context(), nc, nc, m.device(), nc, workspace, Ipiv, info));
            CHECK_HIPSOLVER(hipsolverDnSgetrs(context(), HIPSOLVER_OP_N, nc, nc, m.device(), nc, Ipiv, out.device(), nc, info));
            // Record that work was submitted so the next sync_if_needed() call
            // actually synchronizes.
            did_work_lately = true;
        }

    // ------------------------------------------------------------------------------------

        int inv::
        get_last_status(
        )
        {
            std::vector<int> linfo; 
            memcpy(linfo, info);
            if (linfo.size() != 0)
                return linfo[0];
            else
                return 0;
        }

    // ------------------------------------------------------------------------------------

        // Calls hipDeviceSynchronize() if did_work_lately is set, clearing the flag
        // first; otherwise does nothing.  The status returned by
        // hipDeviceSynchronize() is not inspected.
        void inv::
        sync_if_needed()
        {
            if (!did_work_lately)
                return;

            did_work_lately = false;
            // Wait for previously launched work to finish before the caller does
            // something like deallocating GPU memory.
            hipDeviceSynchronize();
        }

    // ------------------------------------------------------------------------------------

    }  
}

#endif // DLIB_USE_ROCM

#endif // DLIB_DNN_HIPSOLVER_CU_


