import numpy import cupy from cupy.cuda import cublas from cupy.cuda import device from cupy.linalg import _util from cupyx.scipy.linalg import _uarray @_uarray.implements('solve_triangular') def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, check_finite=False): """Solve the equation a x = b for x, assuming a is a triangular matrix. Args: a (cupy.ndarray): The matrix with dimension ``(M, M)``. b (cupy.ndarray): The matrix with dimension ``(M,)`` or ``(M, N)``. lower (bool): Use only data contained in the lower triangle of ``a``. Default is to use upper triangle. trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve: - *'0'* or *'N'* -- :math:`a x = b` - *'1'* or *'T'* -- :math:`a^T x = b` - *'2'* or *'C'* -- :math:`a^H x = b` unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are assumed to be 1 and will not be referenced. overwrite_b (bool): Allow overwriting data in b (may enhance performance) check_finite (bool): Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs. Returns: cupy.ndarray: The matrix with dimension ``(M,)`` or ``(M, N)``. .. seealso:: :func:`scipy.linalg.solve_triangular` """ _util._assert_cupy_array(a, b) if len(a.shape) != 2 or a.shape[0] != a.shape[1]: raise ValueError('expected square matrix') if len(a) != len(b): raise ValueError('incompatible dimensions') # Cast to float32 or float64 if a.dtype.char in 'fd': dtype = a.dtype else: dtype = numpy.promote_types(a.dtype.char, 'f') a = cupy.array(a, dtype=dtype, order='F', copy=False) b = cupy.array(b, dtype=dtype, order='F', copy=(not overwrite_b)) if check_finite: if a.dtype.kind == 'f' and not cupy.isfinite(a).all(): raise ValueError( 'array must not contain infs or NaNs') if b.dtype.kind == 'f' and not cupy.isfinite(b).all(): raise ValueError( 'array must not contain infs or NaNs') m, n = (b.size, 1) if b.ndim == 1 else b.shape cublas_handle = device.get_cublas_handle() if dtype == 'f': trsm = cublas.strsm elif dtype == 'd': trsm = cublas.dtrsm elif dtype == 'F': trsm = cublas.ctrsm else: # dtype == 'D' trsm = cublas.ztrsm one = numpy.array(1, dtype=dtype) if lower: uplo = cublas.CUBLAS_FILL_MODE_LOWER else: uplo = cublas.CUBLAS_FILL_MODE_UPPER if trans == 'N': trans = cublas.CUBLAS_OP_N elif trans == 'T': trans = cublas.CUBLAS_OP_T elif trans == 'C': trans = cublas.CUBLAS_OP_C if unit_diagonal: diag = cublas.CUBLAS_DIAG_UNIT else: diag = cublas.CUBLAS_DIAG_NON_UNIT trsm( cublas_handle, cublas.CUBLAS_SIDE_LEFT, uplo, trans, diag, m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) return b