// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_RocBLAS_CPP_
#define DLIB_DNN_RocBLAS_CPP_

#ifdef DLIB_USE_ROCM

#include "rocblas_dlibapi.h"
#include "rocm_utils.h"
#include "../matrix.h"

#include <rocblas/rocblas.h>
#include <vector>

// Translate a rocblas_status code into a short human-readable description.
//
// Only the statuses singled out below get a dedicated message; every other
// value falls through to a generic rocBLAS failure string.  The returned
// pointer always refers to a string literal, so callers must not free it.
static const char* rocblas_get_error_string(rocblas_status s)
{
    switch(s)
    {
        case rocblas_status_invalid_handle:
            return "rocBLAS handle is not initialized, invalid, or null.";
        case rocblas_status_memory_error:
            return "ROCM Resources could not be allocated.";
        default:
            // This file wraps rocBLAS, so name the right library in the message.
            return "A call to rocBLAS failed";
    }
}

// CHECK_ROCBLAS(call) evaluates a rocBLAS API call exactly once and throws
// dlib::rocblas_error when the result is anything other than
// rocblas_status_success.  The exception text contains the stringified call,
// the file and line of the invocation, the numeric status code and the
// description produced by rocblas_get_error_string().  The do/while(false)
// wrapper lets the macro be used like an ordinary statement followed by ';'.
#define CHECK_ROCBLAS(call)                                                      \
do{                                                                              \
    const rocblas_status status = call;                                          \
    if (status != rocblas_status_success)                                        \
    {                                                                            \
        std::ostringstream msg;                                                  \
        msg << "Error while calling " << #call << " in file " << __FILE__        \
            << ":" << __LINE__ << ". "                                           \
            << "code: " << status << ", reason: "                                \
            << rocblas_get_error_string(status);                                 \
        throw dlib::rocblas_error(msg.str());                                    \
    }                                                                            \
}while(false)

namespace dlib
{
    namespace rocm 
    {

    // -----------------------------------------------------------------------------------

        class rocblas_context
        {
        public:
            // not copyable
            rocblas_context(const rocblas_context&) = delete;
            rocblas_context& operator=(const rocblas_context&) = delete;

            rocblas_context()
            {
                handles.resize(16);
            }
            ~rocblas_context()
            {
                for (auto h : handles)
                {
                    if (h)
                        rocblas_destroy_handle(h);
                }
            }

            rocblas_handle get_handle (
            )  
            { 
                int new_device_id;
                CHECK_ROCM(hipGetDevice(&new_device_id));

                // make room for more devices if needed
                if (new_device_id >= (long)handles.size())
                    handles.resize(new_device_id+16);

                // If we don't have a handle already for this device then make one
                if (!handles[new_device_id]){
                    CHECK_ROCBLAS(rocblas_create_handle(&handles[new_device_id]));
                }

                // Finally, return the handle for the current device
                return handles[new_device_id];
            }

        private:

            std::vector<rocblas_handle> handles;
        };

        // Return the rocBLAS handle to use for the current call.  Each thread
        // owns a separate rocblas_context (constructed on the thread's first
        // call), and that context's get_handle() selects the handle.
        static rocblas_handle context()
        {
            thread_local rocblas_context per_thread_handles;
            rocblas_handle current = per_thread_handles.get_handle();
            return current;
        }

    // -----------------------------------------------------------------------------------
#include "../matrix.h"

        void gemm (
            float beta,
            tensor& dest,
            float alpha,
            const tensor& lhs,
            bool trans_lhs,
            const tensor& rhs,
            bool trans_rhs
        )
        {
            // Recall that BLAS uses column major order so to deal with that we flip the
            // order of the lhs and rhs arguments.
            const auto transa = trans_lhs ? rocblas_operation_transpose : rocblas_operation_none;
            const auto transb = trans_rhs ? rocblas_operation_transpose : rocblas_operation_none;

            const int dest_nr = dest.num_samples();
            const int dest_nc = dest.size()/dest_nr;
            const int lhs_nr = lhs.num_samples();
            const int lhs_nc = lhs.size()/lhs_nr;
            const int rhs_nr = rhs.num_samples();
            const int rhs_nc = rhs.size()/rhs_nr;
            if (trans_lhs && trans_rhs)
            {
                DLIB_ASSERT( dest_nr == lhs_nc &&
                              dest_nc == rhs_nr &&
                              lhs_nr == rhs_nc)
            }
            else if (!trans_lhs && trans_rhs)
            {
                DLIB_ASSERT( dest_nr == lhs_nr &&
                              dest_nc == rhs_nr &&
                              lhs_nc == rhs_nc)
            }
            else if (trans_lhs && !trans_rhs)
            {
                DLIB_ASSERT( dest_nr == lhs_nc &&
                              dest_nc == rhs_nc &&
                              lhs_nr == rhs_nr)
            }
            else
            {
                DLIB_ASSERT( dest_nr == lhs_nr &&
                              dest_nc == rhs_nc &&
                              lhs_nc == rhs_nr)
            }

            const int k = trans_rhs ? rhs_nc : rhs_nr;
            CHECK_ROCBLAS(rocblas_sgemm(context(),
                  transb,
                  transa,
                  dest_nc, dest_nr, k,
                  &alpha,
                  rhs.device(), rhs_nc,
                  lhs.device(), lhs_nc,
                  &beta,
                  dest.device(),dest_nc));
        }

    // ------------------------------------------------------------------------------------

    }  
}

#endif // DLIB_USE_ROCM

#endif // DLIB_DNN_RocBLAS_CPP_



