#ifndef INCLUDE_GUARD_CUPY_CUSOLVER_H #define INCLUDE_GUARD_CUPY_CUSOLVER_H #include #if !defined(CUPY_NO_CUDA) && !defined(CUPY_USE_HIP) #include "cuda/cupy_cusolver.h" #elif defined(CUPY_USE_HIP) // #if !defined(CUPY_NO_CUDA) && !defined(CUPY_USE_HIP) #include "hip/cupy_rocsolver.h" #else // #if !defined(CUPY_NO_CUDA) && !defined(CUPY_USE_HIP) #include "stub/cupy_cusolver.h" #endif // #if !defined(CUPY_NO_CUDA) && !defined(CUPY_USE_HIP) #if !defined(CUPY_USE_HIP) /* * loop-based batched gesvd (only used on CUDA) */ template using gesvd = cusolverStatus_t (*)(cusolverDnHandle_t, signed char, signed char, int, int, T1*, int, T2*, T1*, int, T1*, int, T1*, int, T2*, int*); template struct gesvd_func { gesvd ptr; }; template<> struct gesvd_func { gesvd ptr = cusolverDnSgesvd; }; template<> struct gesvd_func { gesvd ptr = cusolverDnDgesvd; }; template<> struct gesvd_func { gesvd ptr = cusolverDnCgesvd; }; template<> struct gesvd_func { gesvd ptr = cusolverDnZgesvd; }; template int gesvd_loop( intptr_t handle, char jobu, char jobvt, int m, int n, intptr_t a_ptr, intptr_t s_ptr, intptr_t u_ptr, intptr_t vt_ptr, intptr_t w_ptr, int buffersize, intptr_t info_ptr, int batch_size) { /* * Assumptions: * 1. the stream is set prior to calling this function * 2. the workspace is reused in the loop */ cusolverStatus_t status; int k = (m::value) || (std::is_same::value), float, /* double or cuDoubleComplex */ double>::type real_type; T* A = reinterpret_cast(a_ptr); real_type* S = reinterpret_cast(s_ptr); T* U = reinterpret_cast(u_ptr); T* VT = reinterpret_cast(vt_ptr); T* Work = reinterpret_cast(w_ptr); int* devInfo = reinterpret_cast(info_ptr); // we can't use "if constexpr" to do a compile-time branch selection as it's C++17 only, // so we use custom traits instead gesvd func = gesvd_func().ptr; for (int i=0; i(handle), jobu, jobvt, m, n, A, m, S, U, m, VT, n, Work, buffersize, NULL, devInfo); if (status != 0) break; A += m * n; S += k; U += (jobu=='A' ? m*m : (jobu=='S' ? m*k : /* jobu=='O' or 'N' */ 0)); VT += (jobvt=='A' ? n*n : (jobvt=='S' ? n*k : /* jobvt=='O' or 'N' */ 0)); devInfo += 1; } return status; } /* * loop-based batched geqrf (only used on CUDA) */ template using geqrf = cusolverStatus_t (*)(cusolverDnHandle_t, int, int, T*, int, T*, T*, int, int*); template struct geqrf_func { geqrf ptr; }; template<> struct geqrf_func { geqrf ptr = cusolverDnSgeqrf; }; template<> struct geqrf_func { geqrf ptr = cusolverDnDgeqrf; }; template<> struct geqrf_func { geqrf ptr = cusolverDnCgeqrf; }; template<> struct geqrf_func { geqrf ptr = cusolverDnZgeqrf; }; template int geqrf_loop( intptr_t handle, int m, int n, intptr_t a_ptr, int lda, intptr_t tau_ptr, intptr_t w_ptr, int buffersize, intptr_t info_ptr, int batch_size) { /* * Assumptions: * 1. the stream is set prior to calling this function * 2. the workspace is reused in the loop */ cusolverStatus_t status; int k = (m(a_ptr); T* Tau = reinterpret_cast(tau_ptr); T* Work = reinterpret_cast(w_ptr); int* devInfo = reinterpret_cast(info_ptr); // we can't use "if constexpr" to do a compile-time branch selection as it's C++17 only, // so we use custom traits instead geqrf func = geqrf_func().ptr; for (int i=0; i(handle), m, n, A, lda, Tau, Work, buffersize, devInfo); if (status != 0) break; A += m * n; Tau += k; devInfo += 1; } return status; } #else template int gesvd_loop( intptr_t handle, char jobu, char jobvt, int m, int n, intptr_t a_ptr, intptr_t s_ptr, intptr_t u_ptr, intptr_t vt_ptr, intptr_t w_ptr, int buffersize, intptr_t info_ptr, int batch_size) { // we need a dummy stub for HIP as it's not used return 0; } /* * batched geqrf (only used on HIP) */ template using geqrf = cusolverStatus_t (*)(cusolverDnHandle_t, int, int, T* const[], int, T*, long int, int); template struct geqrf_func { geqrf ptr; }; template<> struct geqrf_func { geqrf ptr = rocsolver_sgeqrf_batched; }; template<> struct geqrf_func { geqrf ptr = rocsolver_dgeqrf_batched; }; // we need the correct func pointer here, so can't cast! template<> struct geqrf_func { geqrf ptr = rocsolver_cgeqrf_batched; }; template<> struct geqrf_func { geqrf ptr = rocsolver_zgeqrf_batched; }; template int geqrf_loop( intptr_t handle, int m, int n, intptr_t a_ptr, int lda, intptr_t tau_ptr, intptr_t w_ptr, int buffersize, intptr_t info_ptr, int batch_size) { /* * Assumptions: * 1. the stream is set prior to calling this function * 2. ignore w_ptr, buffersize, and info_ptr as rocSOLVER does not need them */ cusolverStatus_t status; // we can't use "if constexpr" to do a compile-time branch selection as it's C++17 only, // so we use custom traits instead typedef typename std::conditional< std::is_floating_point::value, T, typename std::conditional::value, rocblas_float_complex, rocblas_double_complex>::type >::type data_type; geqrf func = geqrf_func().ptr; data_type* const* A = reinterpret_cast(a_ptr); data_type* Tau = reinterpret_cast(tau_ptr); int k = (m using orgqr = cusolverStatus_t (*)(cusolverDnHandle_t, int, int, int, T*, int, const T*, T*, int, int*); template struct orgqr_func { orgqr ptr; }; template<> struct orgqr_func { orgqr ptr = cusolverDnSorgqr; }; template<> struct orgqr_func { orgqr ptr = cusolverDnDorgqr; }; template<> struct orgqr_func { orgqr ptr = cusolverDnCungqr; }; template<> struct orgqr_func { orgqr ptr = cusolverDnZungqr; }; template int orgqr_loop( intptr_t handle, int m, int n, int k, intptr_t a_ptr, int lda, intptr_t tau_ptr, intptr_t w_ptr, int buffersize, intptr_t info_ptr, int batch_size, int origin_n) { /* * Assumptions: * 1. the stream is set prior to calling this function * 2. the workspace is reused in the loop */ cusolverStatus_t status; T* A = reinterpret_cast(a_ptr); const T* Tau = reinterpret_cast(tau_ptr); T* Work = reinterpret_cast(w_ptr); int* devInfo = reinterpret_cast(info_ptr); // we can't use "if constexpr" to do a compile-time branch selection as it's C++17 only, // so we use custom traits instead orgqr func = orgqr_func().ptr; for (int i=0; i(handle), m, n, k, A, lda, Tau, Work, buffersize, devInfo); if (status != 0) break; A += m * origin_n; Tau += k; devInfo += 1; } return status; } #endif // #ifndef INCLUDE_GUARD_CUPY_CUSOLVER_H