// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_GPU_DaTA_CPP_
#define DLIB_GPU_DaTA_CPP_

// Only the parts that require ROCm are defined in this cpp file.  Everything else lives in
// the gpu_data.h header so that it can operate as "header-only" code when only the CPU is used.
#ifdef DLIB_USE_ROCM

#include "gpu_data.h"
#include <iostream>
#include "rocm_utils.h"
#include <cstring>
#include <hip/hip_runtime.h>

namespace dlib
{

// ----------------------------------------------------------------------------------------

    // Copies the entire contents of src into dest.  Both objects must already have the
    // same size (checked with DLIB_CASSERT).  Copying empty buffers, or copying an object
    // onto itself, does nothing; otherwise the work is delegated to the offset-based
    // overload below.
    void memcpy (
        gpu_data& dest, 
        const gpu_data& src
    )
    {
        DLIB_CASSERT(dest.size() == src.size());

        const bool nothing_to_copy = (src.size() == 0) || (&dest == &src);
        if (!nothing_to_copy)
            memcpy(dest, 0, src, 0, src.size());
    }

    // Copies num floats from src (starting at src_offset) into dest (starting at
    // dest_offset).
    //
    // The hipMemcpy direction is chosen from which copy of each object is current
    // (device_ready()), so data is moved where it already lives whenever possible.  When
    // dest and src are the same object and the two ranges overlap, the copy is done with
    // std::memmove on the host copy, so overlapping ranges are handled correctly.  When the
    // whole of dest is overwritten, device_write_only() is used so that dest's old contents
    // are never transferred.
    void memcpy (
        gpu_data& dest, 
        size_t dest_offset,
        const gpu_data& src,
        size_t src_offset,
        size_t num
    )
    {
        // Phrased as "offset <= size - num" (after checking num <= size) instead of
        // "offset + num <= size": the latter can wrap around size_t for very large
        // offsets, letting an out-of-bounds request slip past the check.
        DLIB_CASSERT(num <= dest.size() && dest_offset <= dest.size() - num);
        DLIB_CASSERT(num <= src.size() && src_offset <= src.size() - num);
        if (num == 0)
            return;

        // if there is aliasing
        if (&dest == &src && std::max(dest_offset, src_offset) < std::min(dest_offset,src_offset)+num)
        {
            // if they perfectly alias each other then there is nothing to do
            if (dest_offset == src_offset)
                return;
            else
                std::memmove(dest.host()+dest_offset, src.host()+src_offset, sizeof(float)*num);
        }
        else
        {
            // if we write to the entire thing then we can use device_write_only()
            if (dest_offset == 0 && num == dest.size())
            {
                // copy the memory efficiently based on which copy is current in each object.
                if (src.device_ready())
                    CHECK_ROCM(hipMemcpy(dest.device_write_only(), src.device()+src_offset,  num*sizeof(float), hipMemcpyDeviceToDevice));
                else 
                    CHECK_ROCM(hipMemcpy(dest.device_write_only(), src.host()+src_offset,    num*sizeof(float), hipMemcpyHostToDevice));
            }
            else
            {
                // copy the memory efficiently based on which copy is current in each object.
                if (dest.device_ready() && src.device_ready())
                    CHECK_ROCM(hipMemcpy(dest.device()+dest_offset, src.device()+src_offset, num*sizeof(float), hipMemcpyDeviceToDevice));
                else if (!dest.device_ready() && src.device_ready())
                    CHECK_ROCM(hipMemcpy(dest.host()+dest_offset, src.device()+src_offset,   num*sizeof(float), hipMemcpyDeviceToHost));
                else if (dest.device_ready() && !src.device_ready())
                    CHECK_ROCM(hipMemcpy(dest.device()+dest_offset, src.host()+src_offset,   num*sizeof(float), hipMemcpyHostToDevice));
                else 
                    CHECK_ROCM(hipMemcpy(dest.host()+dest_offset, src.host()+src_offset,     num*sizeof(float), hipMemcpyHostToHost));
            }
        }
    }
// ----------------------------------------------------------------------------------------

    // Blocks the calling host thread until outstanding GPU work has completed.  Errors
    // are reported through CHECK_ROCM.
    //
    // NOTE: the stream argument is currently ignored.  The whole device is synchronized
    // with hipDeviceSynchronize(), which waits for work on every stream, including the
    // one passed in.  Narrower per-stream alternatives, such as polling hipStreamQuery(stream)
    // or calling hipStreamSynchronize(stream), are not used.
    // NOTE(review): the reason for preferring the device-wide call is not recorded here;
    // confirm before narrowing it.
    void synchronize_stream(hipStream_t /*stream*/)
    {
        CHECK_ROCM(hipDeviceSynchronize());
    }

    void gpu_data::
    wait_for_transfer_to_finish() const
    {
        if (have_active_transfer)
        {
            synchronize_stream((hipStream_t)hip_stream.get());

            have_active_transfer = false;
            // Check for errors.  These calls to hipGetLastError() are what help us find
            // out if our kernel launches have been failing.
            CHECK_ROCM(hipGetLastError());
        }
    }

    // Makes the device copy current and blocks until the upload has finished.  The
    // upload goes through async_copy_to_device(), which issues it on this object's own
    // non-default stream so that it is not serialized with device computation.
    void gpu_data::
    copy_to_device() const
    {
        async_copy_to_device();
        wait_for_transfer_to_finish();
    }

    // Makes the host copy current by downloading the device buffer, if the host copy
    // is stale.  Does nothing when the host copy is already up to date.
    void gpu_data::
    copy_to_host() const
    {
        if (host_current)
            return;

        // Let any pending transfer on our private stream complete first.
        wait_for_transfer_to_finish();

        const size_t num_bytes = data_size*sizeof(float);
        CHECK_ROCM(hipMemcpy(data_host.get(), data_device.get(), num_bytes, hipMemcpyDeviceToHost));
        host_current = true;

        // hipMemcpy() implicitly syncs with the device, so no kernel can still be using
        // our device block at this point.
        device_in_use = false;

        // Asynchronous failures (e.g. from earlier kernel launches) only become visible
        // through hipGetLastError().
        CHECK_ROCM(hipGetLastError());
    }

    void gpu_data::
    async_copy_to_device() const
    {
        if (!device_current)
        {
            if (device_in_use)
            {
                // Wait for any possible ROCM kernels that might be using our memory block to
                // complete before we overwrite the memory.
                synchronize_stream(0);
                device_in_use = false;
            }
            CHECK_ROCM(hipMemcpyAsync(data_device.get(), data_host.get(), data_size*sizeof(float), hipMemcpyHostToDevice, (hipStream_t)hip_stream.get()));
            have_active_transfer = true;
            device_current = true;
        }
    }

    void gpu_data::
    set_size(
        size_t new_size
    )
    {
        if (new_size == 0)
        {
            if (device_in_use)
            {
                // Wait for any possible ROCM kernels that might be using our memory block to
                // complete before we free the memory.
                synchronize_stream(0);
                device_in_use = false;
            }
            wait_for_transfer_to_finish();
            data_size = 0;
            host_current = true;
            device_current = true;
            device_in_use = false;
            data_host.reset();
            data_device.reset();
        }
        else if (new_size != data_size)
        {
            if (device_in_use)
            {
                // Wait for any possible ROCM kernels that might be using our memory block to
                // complete before we free the memory.
                synchronize_stream(0);
                device_in_use = false;
            }
            wait_for_transfer_to_finish();
            data_size = new_size;
            host_current = true;
            device_current = true;
            device_in_use = false;

            try
            {
                CHECK_ROCM(hipGetDevice(&the_device_id));

                // free memory blocks before we allocate new ones.
                data_host.reset();
                data_device.reset();

                void* data;
                CHECK_ROCM(hipMallocHost(&data, new_size*sizeof(float)));
                // Note that we don't throw exceptions since the free calls are invariably
                // called in destructors.  They also shouldn't fail anyway unless someone
                // is resetting the GPU card in the middle of their program.
                data_host.reset((float*)data, [](float* ptr){
                    auto err = hipHostFree(ptr);
                    if(err!=hipSuccess)
                        std::cerr << "hipHostFree() failed. Reason: " << hipGetErrorString(err) << std::endl;
                });

                CHECK_ROCM(hipMalloc(&data, new_size*sizeof(float)));
                data_device.reset((float*)data, [](float* ptr){
                    auto err = hipFree(ptr);
                    if(err!=hipSuccess)
                        std::cerr << "hipFree() failed. Reason: " << hipGetErrorString(err) << std::endl;
                });

                if (!hip_stream)
                {
                    hipStream_t cstream;
                    CHECK_ROCM(hipStreamCreateWithFlags(&cstream, hipStreamNonBlocking));
                    hip_stream.reset(cstream, [](void* ptr){
                        auto err = hipStreamDestroy((hipStream_t)ptr);
                        if(err!=hipSuccess)
                            std::cerr << "hipStreamDestroy() failed. Reason: " << hipGetErrorString(err) << std::endl;
                    });
                }

            }
            catch(...)
            {
                set_size(0);
                throw;
            }
        }
    }

// ----------------------------------------------------------------------------------------
}

#endif // DLIB_USE_ROCM

#endif // DLIB_GPU_DaTA_CPP_

