

include'hipfort_rocblas_enums.f'
include'hipfort_rocfft_enums.f'
include'hipfort_rocsparse_enums.f'
include'hipfort_rocsolver_enums.f'
module cudafor_check
  ! Status-check helpers.
  !
  ! Every subroutine in this module follows the same pattern: compare the
  ! given status code with the success value of the matching enum module
  ! and, if they differ, print the code to standard output and call `exit`
  ! with that code.
  implicit none
  private

  public :: cudaCheck
  public :: cublasCheck, cufftCheck, cusparseCheck, cusolverCheck
  public :: rocblasCheck, rocfftCheck, rocsparseCheck, rocsolverCheck

contains

  ! Stop the program if `cudaError_t` is not `cudaSuccess`.
  !
  ! cudaError_t : status code to test (read only).
  subroutine cudaCheck(cudaError_t)
    use cudafor_enums, only: cudaSuccess

    integer(kind(cudaSuccess)), intent(in) :: cudaError_t

    if (cudaError_t /= cudaSuccess) then
      write (*, *) "CU ERROR: Error code = ", cudaError_t
      ! `exit` is a compiler extension rather than standard Fortran; it is
      ! kept so that the failing code is handed on as the exit status.
      call exit(cudaError_t)
    end if
  end subroutine cudaCheck

  ! HIP math libs
  ! TODO: Currently, only AMDGPU is supported

  ! Stop the program if `cublasError_t` is not `CUBLAS_STATUS_SUCCESS`.
  !
  ! cublasError_t : status code to test (read only).
  subroutine cublasCheck(cublasError_t)
    use cudafor_cublas_enums, only: CUBLAS_STATUS_SUCCESS

    integer(kind(CUBLAS_STATUS_SUCCESS)), intent(in) :: cublasError_t

    if (cublasError_t /= CUBLAS_STATUS_SUCCESS) then
      write (*, *) "CUBLAS ERROR: Error code = ", cublasError_t
      call exit(cublasError_t)
    end if
  end subroutine cublasCheck

  ! Stop the program if `cufft_status` is not `cufft_success`.
  !
  ! cufft_status : status code to test (read only).
  subroutine cufftCheck(cufft_status)
    use cudafor_cufft_enums, only: cufft_success

    integer(kind(cufft_success)), intent(in) :: cufft_status

    if (cufft_status /= cufft_success) then
      write (*, *) "CUFFT ERROR: Error code = ", cufft_status
      call exit(cufft_status)
    end if
  end subroutine cufftCheck

  ! Stop the program if `cusparseError_t` is not `CUSPARSE_STATUS_SUCCESS`.
  !
  ! cusparseError_t : status code to test (read only).
  subroutine cusparseCheck(cusparseError_t)
    use cudafor_cusparse_enums, only: CUSPARSE_STATUS_SUCCESS

    integer(kind(CUSPARSE_STATUS_SUCCESS)), intent(in) :: cusparseError_t

    if (cusparseError_t /= CUSPARSE_STATUS_SUCCESS) then
      write (*, *) "CUSPARSE ERROR: Error code = ", cusparseError_t
      call exit(cusparseError_t)
    end if
  end subroutine cusparseCheck

  ! Stop the program if `cusolverError_t` is not `CUSOLVER_STATUS_SUCCESS`.
  !
  ! cusolverError_t : status code to test (read only).
  subroutine cusolverCheck(cusolverError_t)
    use cudafor_cusolver_enums, only: CUSOLVER_STATUS_SUCCESS

    integer(kind(CUSOLVER_STATUS_SUCCESS)), intent(in) :: cusolverError_t

    if (cusolverError_t /= CUSOLVER_STATUS_SUCCESS) then
      write (*, *) "CUSOLVER ERROR: Error code = ", cusolverError_t
      call exit(cusolverError_t)
    end if
  end subroutine cusolverCheck

  ! ROCm math libs

  ! Stop the program if `rocblasError_t` is not `ROCBLAS_STATUS_SUCCESS`.
  !
  ! rocblasError_t : status code to test (read only).
  subroutine rocblasCheck(rocblasError_t)
    use hipfort_rocblas_enums, only: ROCBLAS_STATUS_SUCCESS

    integer(kind(ROCBLAS_STATUS_SUCCESS)), intent(in) :: rocblasError_t

    if (rocblasError_t /= ROCBLAS_STATUS_SUCCESS) then
      write (*, *) "ROCBLAS ERROR: Error code = ", rocblasError_t
      call exit(rocblasError_t)
    end if
  end subroutine rocblasCheck

  ! Stop the program if `rocfft_status` is not `rocfft_status_success`.
  !
  ! rocfft_status : status code to test (read only).
  subroutine rocfftCheck(rocfft_status)
    use hipfort_rocfft_enums, only: rocfft_status_success

    integer(kind(rocfft_status_success)), intent(in) :: rocfft_status

    if (rocfft_status /= rocfft_status_success) then
      write (*, *) "ROCFFT ERROR: Error code = ", rocfft_status
      call exit(rocfft_status)
    end if
  end subroutine rocfftCheck

  ! Stop the program if `rocsparseError_t` is not `ROCSPARSE_STATUS_SUCCESS`.
  !
  ! rocsparseError_t : status code to test (read only).
  subroutine rocsparseCheck(rocsparseError_t)
    use hipfort_rocsparse_enums, only: ROCSPARSE_STATUS_SUCCESS

    integer(kind(ROCSPARSE_STATUS_SUCCESS)), intent(in) :: rocsparseError_t

    if (rocsparseError_t /= ROCSPARSE_STATUS_SUCCESS) then
      write (*, *) "ROCSPARSE ERROR: Error code = ", rocsparseError_t
      call exit(rocsparseError_t)
    end if
  end subroutine rocsparseCheck

  ! Stop the program if `rocsolverError_t` is not `ROCBLAS_STATUS_SUCCESS`.
  !
  ! The status is declared with, and compared against, the rocBLAS success
  ! value; nothing from the rocSOLVER enum module is referenced here.
  ! NOTE(review): presumably rocSOLVER reports rocBLAS status codes --
  ! confirm against the rocSOLVER API.
  !
  ! rocsolverError_t : status code to test (read only).
  subroutine rocsolverCheck(rocsolverError_t)
    use hipfort_rocblas_enums, only: ROCBLAS_STATUS_SUCCESS

    integer(kind(ROCBLAS_STATUS_SUCCESS)), intent(in) :: rocsolverError_t

    if (rocsolverError_t /= ROCBLAS_STATUS_SUCCESS) then
      write (*, *) "ROCSOLVER ERROR: Error code = ", rocsolverError_t
      call exit(rocsolverError_t)
    end if
  end subroutine rocsolverCheck

end module cudafor_check