// Copyright (C) 2017  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_ROCM_DATA_PTR_CPP_
#define DLIB_DNN_ROCM_DATA_PTR_CPP_

#ifdef DLIB_USE_ROCM

#include "rocm_data_ptr.h"
#include "rocm_utils.h"

namespace dlib
{
    namespace rocm 
    {

    // ----------------------------------------------------------------------------------------

        // Builds a weak handle from an existing rocm_data_void_ptr.  The byte count
        // (num) and the data handle (pdata) of `ptr` are both copied into this object.
        // NOTE(review): pdata is presumably a non-owning (weak) reference here, since
        // lock() has to promote it before use -- confirm against the header.
        weak_rocm_data_void_ptr::
        weak_rocm_data_void_ptr(
            const rocm_data_void_ptr& ptr
        ) :
            num(ptr.num),
            pdata(ptr.pdata)
        {
        }

    // ----------------------------------------------------------------------------------------

        // Attempts to turn this weak handle back into a usable rocm_data_void_ptr.
        //
        // Returns a rocm_data_void_ptr carrying the locked data handle and the
        // recorded size when pdata.lock() yields a non-null handle, and a
        // default-constructed rocm_data_void_ptr otherwise.
        rocm_data_void_ptr weak_rocm_data_void_ptr::
        lock() const 
        {
            auto strong = pdata.lock();

            // Nothing to promote: hand back an empty pointer object.
            if (!strong)
                return rocm_data_void_ptr();

            rocm_data_void_ptr result;
            result.pdata = strong;
            result.num = num;
            return result;
        }

    // -----------------------------------------------------------------------------------
    // -----------------------------------------------------------------------------------

        // Creates a pointer object describing a block of n bytes obtained from
        // hipMalloc().  When n is 0 no allocation is requested and pdata is left
        // untouched; only num is recorded.
        //
        // The allocation is handed to pdata together with a deleter that calls
        // hipFree().  A failing hipFree() is only reported on std::cerr; the
        // deleter does not throw.
        rocm_data_void_ptr::
        rocm_data_void_ptr(
            size_t n
        ) : num(n)
        {
            if (n != 0)
            {
                void* data = nullptr;
                CHECK_ROCM(hipMalloc(&data, n));

                // Deleter run by pdata when the block is no longer referenced.
                auto release_block = [](void* ptr)
                {
                    auto err = hipFree(ptr);
                    if (err != hipSuccess)
                    {
                        std::cerr << "hipFree() failed. Reason: " << hipGetErrorString(err) << std::endl;
                    }
                };

                pdata.reset(data, release_block);
            }
        }

    // ------------------------------------------------------------------------------------

        // Copies the first num bytes referenced by src into the memory at dest
        // using hipMemcpy() with hipMemcpyDefault.
        //
        // Requires num <= src.size() (checked with DLIB_ASSERT).  When src is
        // empty (size() == 0) no copy is issued at all.
        void memcpy(
            void* dest,
            const rocm_data_void_ptr& src,
            const size_t num
        )
        {
            DLIB_ASSERT(num <= src.size());

            // An empty source has nothing to read from.
            if (src.size() == 0)
                return;

            CHECK_ROCM(hipMemcpy(dest, src.data(),  num, hipMemcpyDefault));
        }

    // ------------------------------------------------------------------------------------

        void memcpy(
            void* dest,
            const rocm_data_void_ptr& src
        )
        {
            memcpy(dest, src, src.size());
        }

    // ------------------------------------------------------------------------------------

        // Copies num bytes from src into the memory referenced by dest using
        // hipMemcpy() with hipMemcpyDefault.
        //
        // Requires num <= dest.size() (checked with DLIB_ASSERT).  When dest is
        // empty (size() == 0) no copy is issued at all.
        void memcpy(
            rocm_data_void_ptr dest, 
            const void* src,
            const size_t num
        )
        {
            DLIB_ASSERT(num <= dest.size());

            // An empty destination has nowhere to write to.
            if (dest.size() == 0)
                return;

            CHECK_ROCM(hipMemcpy(dest.data(), src, num, hipMemcpyDefault));
        }

    // ------------------------------------------------------------------------------------

        // Convenience overload: fills the whole of dest (dest.size() bytes) from
        // src by forwarding to the three-argument memcpy().
        void memcpy(
            rocm_data_void_ptr dest, 
            const void* src
        )
        {
            const size_t total_bytes = dest.size();
            memcpy(dest, src, total_bytes);
        }

    // ------------------------------------------------------------------------------------

        // Keeps one weak reference to a scratch buffer per device id and hands out
        // a buffer of at least the requested size for whichever device
        // hipGetDevice() reports as current.
        //
        // Only weak references are stored, so this object by itself does not keep
        // a buffer alive; get() allocates a fresh one whenever the stored reference
        // can no longer be locked or the locked buffer is too small.
        class miopen_device_buffer
        {
        public:
            // Start with room for 16 device slots; get() grows this on demand.
            miopen_device_buffer() : buffers(16)
            {
            }

            ~miopen_device_buffer() = default;

            // not copyable
            miopen_device_buffer(const miopen_device_buffer&) = delete;
            miopen_device_buffer& operator=(const miopen_device_buffer&) = delete;

            // Returns a buffer of at least `size` bytes for the current device,
            // reusing the previously handed out one when it is still alive and
            // large enough.
            rocm_data_void_ptr get (
                size_t size
            )
            {
                int new_device_id;
                CHECK_ROCM(hipGetDevice(&new_device_id));

                // Grow the slot table when this device id has no slot yet.
                if (static_cast<long>(buffers.size()) <= new_device_id)
                    buffers.resize(new_device_id+16);

                // Taken after the resize above so the reference stays valid.
                weak_rocm_data_void_ptr& slot = buffers[new_device_id];

                rocm_data_void_ptr buff = slot.lock();
                if (!buff || buff.size() < size)
                {
                    // Missing or too small: allocate a replacement and remember it.
                    buff = rocm_data_void_ptr(size);
                    slot = buff;
                }

                return buff;
            }

        private:

            // Indexed by device id.
            std::vector<weak_rocm_data_void_ptr> buffers;
        };

    // ----------------------------------------------------------------------------------------

        // Returns a buffer of at least `size` bytes for the current device, served
        // from a miopen_device_buffer that is thread_local, i.e. each thread has
        // its own instance.
        rocm_data_void_ptr device_global_buffer(size_t size) 
        {
            thread_local miopen_device_buffer per_thread_buffers;
            return per_thread_buffers.get(size);
        }

    // ------------------------------------------------------------------------------------

    }  
}

#endif // DLIB_USE_ROCM

#endif // DLIB_DNN_ROCM_DATA_PTR_CPP_


