import numpy import cupy from cupy import _core from cupy._core import internal from cupyx.scipy.ndimage import _util from cupyx.scipy import special def _get_output_fourier(output, input, complex_only=False): types = [cupy.complex64, cupy.complex128] if not complex_only: types += [cupy.float32, cupy.float64] if output is None: if input.dtype in types: output = cupy.empty(input.shape, dtype=input.dtype) else: output = cupy.empty(input.shape, dtype=types[-1]) elif type(output) is type: if output not in types: raise RuntimeError('output type not supported') output = cupy.empty(input.shape, dtype=output) elif output.shape != input.shape: raise RuntimeError('output shape not correct') return output def _reshape_nd(arr, ndim, axis): """Promote a 1d array to ndim with non-singleton size along axis.""" nd_shape = (1,) * axis + (arr.size,) + (1,) * (ndim - axis - 1) return arr.reshape(nd_shape) def fourier_gaussian(input, sigma, n=-1, axis=-1, output=None): """Multidimensional Gaussian shift filter. The array is multiplied with the Fourier transform of a (separable) Gaussian kernel. Args: input (cupy.ndarray): The input array. sigma (float or sequence of float): The sigma of the Gaussian kernel. If a float, `sigma` is the same for all axes. If a sequence, `sigma` has to contain one value for each axis. n (int, optional): If `n` is negative (default), then the input is assumed to be the result of a complex fft. If `n` is larger than or equal to zero, the input is assumed to be the result of a real fft, and `n` gives the length of the array before transformation along the real transform direction. axis (int, optional): The axis of the real transform (only used when ``n > -1``). output (cupy.ndarray, optional): If given, the result of shifting the input is placed in this array. Returns: output (cupy.ndarray): The filtered output. """ ndim = input.ndim output = _get_output_fourier(output, input) axis = internal._normalize_axis_index(axis, ndim) sigmas = _util._fix_sequence_arg(sigma, ndim, 'sigma') _core.elementwise_copy(input, output) for ax, (sigmak, ax_size) in enumerate(zip(sigmas, output.shape)): # compute the frequency grid in Hz if ax == axis and n > 0: arr = cupy.arange(ax_size, dtype=output.real.dtype) arr /= n else: arr = cupy.fft.fftfreq(ax_size) arr = arr.astype(output.real.dtype, copy=False) # compute the Gaussian weights arr *= arr scale = sigmak * sigmak / -2 arr *= (4 * numpy.pi * numpy.pi) * scale cupy.exp(arr, out=arr) # reshape for broadcasting arr = _reshape_nd(arr, ndim=ndim, axis=ax) output *= arr return output def fourier_uniform(input, size, n=-1, axis=-1, output=None): """Multidimensional uniform shift filter. The array is multiplied with the Fourier transform of a box of given size. Args: input (cupy.ndarray): The input array. size (float or sequence of float): The sigma of the box used for filtering. If a float, `size` is the same for all axes. If a sequence, `size` has to contain one value for each axis. n (int, optional): If `n` is negative (default), then the input is assumed to be the result of a complex fft. If `n` is larger than or equal to zero, the input is assumed to be the result of a real fft, and `n` gives the length of the array before transformation along the real transform direction. axis (int, optional): The axis of the real transform (only used when ``n > -1``). output (cupy.ndarray, optional): If given, the result of shifting the input is placed in this array. Returns: output (cupy.ndarray): The filtered output. """ ndim = input.ndim output = _get_output_fourier(output, input) axis = internal._normalize_axis_index(axis, ndim) sizes = _util._fix_sequence_arg(size, ndim, 'size') _core.elementwise_copy(input, output) for ax, (size, ax_size) in enumerate(zip(sizes, output.shape)): # compute the frequency grid in Hz if ax == axis and n > 0: arr = cupy.arange(ax_size, dtype=output.real.dtype) arr /= n else: arr = cupy.fft.fftfreq(ax_size) arr = arr.astype(output.real.dtype, copy=False) # compute the uniform filter weights arr *= size cupy.sinc(arr, out=arr) # reshape for broadcasting arr = _reshape_nd(arr, ndim=ndim, axis=ax) output *= arr return output def fourier_shift(input, shift, n=-1, axis=-1, output=None): """Multidimensional Fourier shift filter. The array is multiplied with the Fourier transform of a shift operation. Args: input (cupy.ndarray): The input array. This should be in the Fourier domain. shift (float or sequence of float): The size of shift. If a float, `shift` is the same for all axes. If a sequence, `shift` has to contain one value for each axis. n (int, optional): If `n` is negative (default), then the input is assumed to be the result of a complex fft. If `n` is larger than or equal to zero, the input is assumed to be the result of a real fft, and `n` gives the length of the array before transformation along the real transform direction. axis (int, optional): The axis of the real transform (only used when ``n > -1``). output (cupy.ndarray, optional): If given, the result of shifting the input is placed in this array. Returns: output (cupy.ndarray): The shifted output (in the Fourier domain). """ ndim = input.ndim output = _get_output_fourier(output, input, complex_only=True) axis = internal._normalize_axis_index(axis, ndim) shifts = _util._fix_sequence_arg(shift, ndim, 'shift') _core.elementwise_copy(input, output) for ax, (shiftk, ax_size) in enumerate(zip(shifts, output.shape)): if shiftk == 0: continue if ax == axis and n > 0: # cp.fft.rfftfreq(ax_size) * (-2j * numpy.pi * shiftk * ax_size/n) arr = cupy.arange(ax_size, dtype=output.dtype) arr *= -2j * numpy.pi * shiftk / n else: arr = cupy.fft.fftfreq(ax_size) arr = arr * (-2j * numpy.pi * shiftk) cupy.exp(arr, out=arr) # reshape for broadcasting arr = _reshape_nd(arr, ndim=ndim, axis=ax) output *= arr return output def fourier_ellipsoid(input, size, n=-1, axis=-1, output=None): """Multidimensional ellipsoid Fourier filter. The array is multiplied with the fourier transform of a ellipsoid of given sizes. Args: input (cupy.ndarray): The input array. size (float or sequence of float): The size of the box used for filtering. If a float, `size` is the same for all axes. If a sequence, `size` has to contain one value for each axis. n (int, optional): If `n` is negative (default), then the input is assumed to be the result of a complex fft. If `n` is larger than or equal to zero, the input is assumed to be the result of a real fft, and `n` gives the length of the array before transformation along the real transform direction. axis (int, optional): The axis of the real transform (only used when ``n > -1``). output (cupy.ndarray, optional): If given, the result of shifting the input is placed in this array. Returns: output (cupy.ndarray): The filtered output. """ ndim = input.ndim if ndim == 1: return fourier_uniform(input, size, n, axis, output) if ndim > 3: # Note: SciPy currently does not do any filtering on >=4d inputs, but # does not warn about this! raise NotImplementedError('Only 1d, 2d and 3d inputs are supported') output = _get_output_fourier(output, input) axis = internal._normalize_axis_index(axis, ndim) sizes = _util._fix_sequence_arg(size, ndim, 'size') _core.elementwise_copy(input, output) # compute the distance from the origin for all samples in Fourier space distance = 0 for ax, (size, ax_size) in enumerate(zip(sizes, output.shape)): # compute the frequency grid in Hz if ax == axis and n > 0: arr = cupy.arange(ax_size, dtype=output.real.dtype) arr *= numpy.pi * size / n else: arr = cupy.fft.fftfreq(ax_size) arr *= numpy.pi * size arr = arr.astype(output.real.dtype, copy=False) arr *= arr arr = _reshape_nd(arr, ndim=ndim, axis=ax) distance = distance + arr cupy.sqrt(distance, out=distance) if ndim == 2: special.j1(distance, out=output) output *= 2 output /= distance elif ndim == 3: cupy.sin(distance, out=output) output -= distance * cupy.cos(distance) output *= 3 output /= distance ** 3 output[(0,) * ndim] = 1.0 # avoid NaN in corner at frequency=0 location output *= input return output