Unverified Commit 39eec294 authored by pfeatherstone's avatar pfeatherstone Committed by GitHub
Browse files

Arbitrary sized FFTs using modified kissFFT as default backend and MKL otherwise (#2253)



* [FFT] added kissfft wrappers, moved kiss and mkl wrappers into separate files, call the right functions in matrix_fft.h
Co-authored-by: default avatarpf <pf@pf-ubuntu-dev>
Co-authored-by: default avatarDavis King <davis@dlib.net>
parent 71dd9a6c
...@@ -86,6 +86,7 @@ if (UNIX OR MINGW) ...@@ -86,6 +86,7 @@ if (UNIX OR MINGW)
if (SIZE_OF_VOID_PTR EQUAL 8) if (SIZE_OF_VOID_PTR EQUAL 8)
set( mkl_search_path set( mkl_search_path
/opt/intel/oneapi/mkl/latest/lib/intel64
/opt/intel/mkl/*/lib/em64t /opt/intel/mkl/*/lib/em64t
/opt/intel/mkl/lib/intel64 /opt/intel/mkl/lib/intel64
/opt/intel/lib/intel64 /opt/intel/lib/intel64
...@@ -99,6 +100,7 @@ if (UNIX OR MINGW) ...@@ -99,6 +100,7 @@ if (UNIX OR MINGW)
mark_as_advanced(mkl_intel) mark_as_advanced(mkl_intel)
else() else()
set( mkl_search_path set( mkl_search_path
/opt/intel/oneapi/mkl/latest/lib/ia32
/opt/intel/mkl/*/lib/32 /opt/intel/mkl/*/lib/32
/opt/intel/mkl/lib/ia32 /opt/intel/mkl/lib/ia32
/opt/intel/lib/ia32 /opt/intel/lib/ia32
...@@ -114,6 +116,7 @@ if (UNIX OR MINGW) ...@@ -114,6 +116,7 @@ if (UNIX OR MINGW)
# Get mkl_include_dir # Get mkl_include_dir
set(mkl_include_search_path set(mkl_include_search_path
/opt/intel/oneapi/mkl/latest/include
/opt/intel/mkl/include /opt/intel/mkl/include
/opt/intel/include /opt/intel/include
) )
......
#ifndef DLIB_FFT_SIZE_H
#define DLIB_FFT_SIZE_H
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <numeric>
#include "../assert.h"
#include "../hash.h"
namespace dlib
{
class fft_size
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This object is a container used to store the dimensions of an FFT
            operation. It is implemented as a stack-based container with an
            upper bound of 5 dimensions (batch,channels,height,width,depth).
            All dimensions must be strictly positive.

            The object is either default constructed, constructed with an
            initialiser list or with a pair of iterators.

            If default-constructed, the object is empty (num_dims() == 0 and
            num_elements() == 0) and is not a valid FFT size. That is, FFT
            functions will throw if attempted to be used with such an object.

            If constructed with an initialiser list L, the object is properly
            initialised provided:
                - L.size() > 0 and L.size() <= 5
                - L contains strictly positive values

            If constructed with a pair of iterators, the behaviour of the
            constructor is exactly the same as if constructed with an
            initializer list spanned by those iterators.

            Once the object is constructed, it is immutable.
    !*/
public:
    using container_type  = std::array<long,5>;
    using const_reference = container_type::const_reference;
    using iterator        = container_type::iterator;
    using const_iterator  = container_type::const_iterator;

    fft_size() = default;
    /*!
        ensures
            - *this is empty
            - num_dims() == 0
            - num_elements() == 0
    !*/

    template<typename ConstIterator>
    fft_size(ConstIterator dims_begin, ConstIterator dims_end)
    /*!
        requires
            - ConstIterator is const iterator type that points to a long object
            - std::distance(dims_begin, dims_end) > 0
            - std::distance(dims_begin, dims_end) <= 5
            - range contains strictly positive values
        ensures
            - *this is properly initialised
            - num_dims() == std::distance(dims_begin, dims_end)
            - num_elements() == product of all values in range
    !*/
    {
        const size_t ndims = std::distance(dims_begin, dims_end);
        DLIB_ASSERT(ndims > 0, "fft_size objects must be non-empty");
        DLIB_ASSERT(ndims <= _dims.size(), "fft_size objects must have size less than 6");
        DLIB_ASSERT(std::find_if(dims_begin, dims_end, [](long dim) {return dim <= 0;}) == dims_end, "fft_size objects must contain strictly positive values");

        std::copy(dims_begin, dims_end, _dims.begin());
        _size = ndims;
        // std::accumulate carries its running result in the type of the initial value.
        // The initial value therefore has to be a long; with an int literal every partial
        // product would be converted back to int and large sizes would be truncated.
        _num_elements = std::accumulate(dims_begin, dims_end, 1L, std::multiplies<long>());
    }

    fft_size(std::initializer_list<long> dims)
    : fft_size(dims.begin(), dims.end())
    /*!
        requires
            - dims.size() > 0 and dims.size() <= 5
            - dims contains strictly positive values
        ensures
            - *this is properly initialised
            - num_dims() == dims.size()
            - num_elements() == product of all values in dims
    !*/
    {
    }

    size_t num_dims() const
    /*!
        ensures
            - returns the number of dimensions
    !*/
    {
        return _size;
    }

    long num_elements() const
    /*!
        ensures
            - if num_dims() > 0, returns the product of all dimensions, i.e. the total number
              of elements
            - if num_dims() == 0, returns 0
    !*/
    {
        return _num_elements;
    }

    const_reference operator[](size_t index) const
    /*!
        requires
            - index < num_dims()
        ensures
            - returns a const reference to the dimension at position index
    !*/
    {
        DLIB_ASSERT(index < _size, "index " << index << " out of range [0," << _size << ")");
        return _dims[index];
    }

    const_reference back() const
    /*!
        requires
            - num_dims() > 0
        ensures
            - returns a const reference to (*this)[num_dims()-1]
    !*/
    {
        DLIB_ASSERT(_size > 0, "object is empty");
        return _dims[_size-1];
    }

    const_iterator begin() const
    /*!
        ensures
            - returns a const iterator that points to the first dimension
              in this container or end() if the array is empty.
    !*/
    {
        return _dims.begin();
    }

    const_iterator end() const
    /*!
        ensures
            - returns a const iterator that points to one past the end of
              the container.
    !*/
    {
        return _dims.begin() + _size;
    }

    bool operator==(const fft_size& other) const
    /*!
        ensures
            - returns true if two fft_size objects have same size and same dimensions, i.e. if they have identical states
            - only the first num_dims() stored values take part in the comparison
    !*/
    {
        return this->_size == other._size && std::equal(begin(), end(), other.begin());
    }

private:
    size_t _size = 0;
    // Stored as long so that it matches the return type of num_elements().
    long _num_elements = 0;
    // Value-initialised so that the unused tail never holds indeterminate values.
    container_type _dims{};
};
inline dlib::uint32 hash(
    const fft_size& item,
    dlib::uint32 seed = 0)
{
    // Mix in the number of dimensions first, then chain every dimension through
    // dlib::hash, each step using the previous digest as its seed.
    dlib::uint32 digest = dlib::hash(static_cast<dlib::uint64>(item.num_dims()), seed);
    for (const long dim : item)
    {
        digest = dlib::hash(static_cast<dlib::uint64>(dim), digest);
    }
    return digest;
}
/*!
    ensures
        - returns a 32bit hash of the data stored in item, i.e. of item.num_dims()
          followed by each dimension in order.
!*/
inline fft_size pop_back(const fft_size& size)
{
    DLIB_ASSERT(size.num_dims() > 0);
    // Removing the only dimension leaves nothing behind. The iterator constructor
    // requires a non-empty range, so hand back an empty (default constructed) object
    // instead of building one from an empty range.
    if (size.num_dims() <= 1)
        return fft_size();
    return fft_size(size.begin(), size.end() - 1);
}
/*!
    requires
        - size.num_dims() > 0
    ensures
        - returns a copy of size with the last dimension removed.
        - if size.num_dims() == 1, the returned object is empty, i.e. it has
          num_dims() == 0 and num_elements() == 0.
!*/
inline fft_size squeeze_ones(const fft_size size)
{
    DLIB_ASSERT(size.num_dims() > 0);

    // Every dimension is 1: keep a single 1 so that the result is not empty.
    if (size.num_elements() == 1)
        return fft_size{1};

    // Otherwise keep, in order, only the dimensions that differ from 1.
    fft_size::container_type kept;
    const auto kept_end = std::copy_if(size.begin(), size.end(), kept.begin(),
                                       [](long dim) { return dim != 1; });
    return fft_size(kept.begin(), kept_end);
}
/*!
    requires
        - size.num_dims() > 0
    ensures
        - removes dimensions with values equal to 1, yielding a new fft_size object with the same num_elements() but fewer dimensions
        - if size.num_elements() == 1, returns fft_size{1}
!*/
}
#endif //DLIB_FFT_SIZE_H
This diff is collapsed.
This diff is collapsed.
...@@ -11,8 +11,8 @@ namespace dlib ...@@ -11,8 +11,8 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
bool is_power_of_two ( constexpr bool is_power_of_two (
const unsigned long& value const unsigned long value
); );
/*! /*!
ensures ensures
...@@ -21,7 +21,27 @@ namespace dlib ...@@ -21,7 +21,27 @@ namespace dlib
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
constexpr long fftr_nc_size(
long nc
);
/*!
ensures
- returns the number of columns of the matrix produced by fftr() for an input with nc columns (the output dimension of a 1D real FFT)
!*/
// ----------------------------------------------------------------------------------------
constexpr long ifftr_nc_size(
long nc
);
/*!
ensures
- returns the number of columns of the matrix produced by ifftr() for an input with nc columns (the output dimension of an inverse 1D real FFT)
!*/
// ----------------------------------------------------------------------------------------
template <typename EXP> template <typename EXP>
typename EXP::matrix_type fft ( typename EXP::matrix_type fft (
const matrix_exp<EXP>& data const matrix_exp<EXP>& data
...@@ -29,8 +49,6 @@ namespace dlib ...@@ -29,8 +49,6 @@ namespace dlib
/*! /*!
requires requires
- data contains elements of type std::complex<> that itself contains double, float, or long double. - data contains elements of type std::complex<> that itself contains double, float, or long double.
- is_power_of_two(data.nr()) == true
- is_power_of_two(data.nc()) == true
ensures ensures
- Computes the 1 or 2 dimensional discrete Fourier transform of the given data - Computes the 1 or 2 dimensional discrete Fourier transform of the given data
matrix and returns it. In particular, we return a matrix D such that: matrix and returns it. In particular, we return a matrix D such that:
...@@ -39,9 +57,9 @@ namespace dlib ...@@ -39,9 +57,9 @@ namespace dlib
- D(0,0) == the DC term of the Fourier transform. - D(0,0) == the DC term of the Fourier transform.
- starting with D(0,0), D contains progressively higher frequency components - starting with D(0,0), D contains progressively higher frequency components
of the input data. of the input data.
- ifft(D) == D - ifft(D) == data
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename EXP> template <typename EXP>
...@@ -51,8 +69,6 @@ namespace dlib ...@@ -51,8 +69,6 @@ namespace dlib
/*! /*!
requires requires
- data contains elements of type std::complex<> that itself contains double, float, or long double. - data contains elements of type std::complex<> that itself contains double, float, or long double.
- is_power_of_two(data.nr()) == true
- is_power_of_two(data.nc()) == true
ensures ensures
- Computes the 1 or 2 dimensional inverse discrete Fourier transform of the - Computes the 1 or 2 dimensional inverse discrete Fourier transform of the
given data vector and returns it. In particular, we return a matrix D such given data vector and returns it. In particular, we return a matrix D such
...@@ -62,8 +78,47 @@ namespace dlib ...@@ -62,8 +78,47 @@ namespace dlib
- fft(D) == data - fft(D) == data
!*/ !*/
// ----------------------------------------------------------------------------------------
template <typename EXP>
matrix<add_complex_t<typename EXP::type>> fftr (
const matrix_exp<EXP>& data
);
/*!
requires
- data contains elements of type double, float, or long double.
- data.nc() is even
ensures
- Computes the 1 or 2 dimensional real discrete Fourier transform of the given data
matrix and returns it. In particular, we return a matrix D such that:
- D.nr() == data.nr()
- D.nc() == fftr_nc_size(data.nc())
- D(0,0) == the DC term of the Fourier transform.
- starting with D(0,0), D contains progressively higher frequency components
of the input data.
- ifftr(D) == data
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename EXP>
matrix<remove_complex_t<typename EXP::type>> ifftr (
const matrix_exp<EXP>& data
);
/*!
requires
- data contains elements of type std::complex<> that itself contains double, float, or long double.
ensures
- Computes the 1 or 2 dimensional inverse real discrete Fourier transform of the
given data vector and returns it. In particular, we return a matrix D such
that:
- D.nr() == data.nr()
- D.nc() == ifftr_nc_size(data.nc())
- fftr(D) == data
!*/
// ----------------------------------------------------------------------------------------
template < template <
typename T, typename T,
long NR, long NR,
...@@ -77,8 +132,6 @@ namespace dlib ...@@ -77,8 +132,6 @@ namespace dlib
/*! /*!
requires requires
- data contains elements of type std::complex<> that itself contains double, float, or long double. - data contains elements of type std::complex<> that itself contains double, float, or long double.
- is_power_of_two(data.nr()) == true
- is_power_of_two(data.nc()) == true
ensures ensures
- This function is identical to fft() except that it does the FFT in-place. - This function is identical to fft() except that it does the FFT in-place.
That is, after this function executes we will have: That is, after this function executes we will have:
...@@ -100,8 +153,6 @@ namespace dlib ...@@ -100,8 +153,6 @@ namespace dlib
/*! /*!
requires requires
- data contains elements of type std::complex<> that itself contains double, float, or long double. - data contains elements of type std::complex<> that itself contains double, float, or long double.
- is_power_of_two(data.nr()) == true
- is_power_of_two(data.nc()) == true
ensures ensures
- This function is identical to ifft() except that it does the inverse FFT - This function is identical to ifft() except that it does the inverse FFT
in-place. That is, after this function executes we will have: in-place. That is, after this function executes we will have:
......
...@@ -48,6 +48,40 @@ namespace dlib ...@@ -48,6 +48,40 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
/*!A remove_complex
    Type trait that strips std::complex from a type and leaves every other
    type untouched.
    For example:
        remove_complex<float>::type == float
        remove_complex<std::complex<float> >::type == float
    remove_complex_t<T> is shorthand for typename remove_complex<T>::type.
!*/
template <typename T>
struct remove_complex
{
    using type = T;
};

template <typename T>
struct remove_complex<std::complex<T> >
{
    using type = T;
};

template <typename T>
using remove_complex_t = typename remove_complex<T>::type;
// ----------------------------------------------------------------------------------------
/*!A add_complex
    Type trait that wraps a type in std::complex, unless the type already is a
    std::complex, in which case it is returned unchanged.
    For example:
        add_complex<float>::type == std::complex<float>
        add_complex<std::complex<float> >::type == std::complex<float>
    add_complex_t<T> is shorthand for typename add_complex<T>::type.
!*/
template <typename T>
struct add_complex
{
    using type = std::complex<T>;
};

template <typename T>
struct add_complex<std::complex<T> >
{
    using type = std::complex<T>;
};

template <typename T>
using add_complex_t = typename add_complex<T>::type;
// ----------------------------------------------------------------------------------------
template <typename EXP> template <typename EXP>
inline bool is_row_vector ( inline bool is_row_vector (
const matrix_exp<EXP>& m const matrix_exp<EXP>& m
......
#ifndef DLIB_MKL_FFT_H
#define DLIB_MKL_FFT_H
#include <type_traits>
#include <mkl_dfti.h>
#include "fft_size.h"
// Throws dlib::error carrying the DftiErrorMessage() text when the DFTI status `s`
// is non-zero and is not classified as DFTI_NO_ERROR.
// The body is wrapped in do { } while (0) so that the macro expands to exactly one
// statement and can be used safely inside un-braced if/else. The argument is copied
// into a local first so that it is evaluated only once.
#define DLIB_DFTI_CHECK_STATUS(s) \
    do { \
        const MKL_LONG dlib_dfti_status_ = (s); \
        if (dlib_dfti_status_ != 0 && !DftiErrorClass(dlib_dfti_status_, DFTI_NO_ERROR)) \
        { \
            throw dlib::error(DftiErrorMessage(dlib_dfti_status_)); \
        } \
    } while (0)
namespace dlib
{
template<typename T>
void mkl_fft(const fft_size& dims, const std::complex<T>* in, std::complex<T>* out, bool is_inverse)
/*!
requires
- T must be either float or double
- dims represents the dimensions of both `in` and `out`
- dims.num_dims() > 0
- dims.num_dims() < 3
ensures
- performs an FFT on `in` and stores the result in `out`.
- if `is_inverse` is true, a backward FFT is performed,
otherwise a forward FFT is performed.
!*/
{
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floatint point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
DLIB_ASSERT(dims.num_dims() < 3, "we currently only support up to 2D FFT. Please submit an issue on github if 3D or above is required.");
constexpr DFTI_CONFIG_VALUE dfti_type = std::is_same<T,float>::value ? DFTI_SINGLE : DFTI_DOUBLE;
DFTI_DESCRIPTOR_HANDLE h;
MKL_LONG status;
if (dims.num_dims() == 1)
{
status = DftiCreateDescriptor(&h, dfti_type, DFTI_COMPLEX, 1, dims[0]);
DLIB_DFTI_CHECK_STATUS(status);
}
else
{
MKL_LONG size[] = {dims[0], dims[1]};
status = DftiCreateDescriptor(&h, dfti_type, DFTI_COMPLEX, 2, size);
DLIB_DFTI_CHECK_STATUS(status);
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = size[1];
strides[2] = 1;
status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
const DFTI_CONFIG_VALUE inplacefft = in == out ? DFTI_INPLACE : DFTI_NOT_INPLACE;
status = DftiSetValue(h, DFTI_PLACEMENT, inplacefft);
DLIB_DFTI_CHECK_STATUS(status);
// Unless we use sequential mode, the fft results are not correct.
status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiCommitDescriptor(h);
DLIB_DFTI_CHECK_STATUS(status);
if (is_inverse)
status = DftiComputeBackward(h, (void*)in, (void*)out);
else
status = DftiComputeForward(h, (void*)in, (void*)out);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiFreeDescriptor(&h);
DLIB_DFTI_CHECK_STATUS(status);
}
/*
* in has dims[0] * dims[1] * ... * dims[-2] * dims[-1] points
* out has dims[0] * dims[1] * ... * dims[-2] * (dims[-1]/2+1) points
*/
template<typename T>
void mkl_fftr(const fft_size& dims, const T* in, std::complex<T>* out)
/*!
requires
- T must be either float or double
- dims represent the dimensions of `in`
- `out` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]/2+1}
- dims.num_dims() > 0
- dims.num_dims() <= 3
- dims.back() must be even
ensures
- performs a real FFT on `in` and stores the result in `out`.
!*/
{
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floatint point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
DLIB_ASSERT(dims.num_dims() < 3, "we currently only support up to 2D FFT. Please submit an issue on github if 3D or above is required.");
DLIB_ASSERT(dims.back() % 2 == 0, "last dimension needs to be even");
constexpr DFTI_CONFIG_VALUE dfti_type = std::is_same<T,float>::value ? DFTI_SINGLE : DFTI_DOUBLE;
DFTI_DESCRIPTOR_HANDLE h;
MKL_LONG status;
if (dims.num_dims() == 1)
{
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 1, dims[0]);
DLIB_DFTI_CHECK_STATUS(status);
}
else
{
const long lastdim = dims[1]/2+1;
MKL_LONG size[] = {dims[0], dims[1]};
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 2, size);
DLIB_DFTI_CHECK_STATUS(status);
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = size[1];
strides[2] = 1;
status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = lastdim;
strides[2] = 1;
status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
}
const DFTI_CONFIG_VALUE inplacefft = (void*)in == (void*)out ? DFTI_INPLACE : DFTI_NOT_INPLACE;
status = DftiSetValue(h, DFTI_PLACEMENT, inplacefft);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiSetValue(h, DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
DLIB_DFTI_CHECK_STATUS(status);
// Unless we use sequential mode, the fft results are not correct.
status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiCommitDescriptor(h);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiComputeForward(h, (void*)in, (void*)out);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiFreeDescriptor(&h);
DLIB_DFTI_CHECK_STATUS(status);
}
/*
* in has dims[0] * dims[1] * ... * dims[-2] * (dims[-1]/2+1) points
* out has dims[0] * dims[1] * ... * dims[-2] * dims[-1] points
*/
template<typename T>
void mkl_ifftr(const fft_size& dims, const std::complex<T>* in, T* out)
/*!
requires
- T must be either float or double
- dims represent the dimensions of `out`
- `in` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]/2+1}
- dims.num_dims() > 0
- dims.num_dims() <= 3
- dims.back() must be even
ensures
- performs an inverse real FFT on `in` and stores the result in `out`.
!*/
{
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floatint point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
DLIB_ASSERT(dims.num_dims() < 3, "we currently only support up to 2D FFT. Please submit an issue on github if 3D or above is required.");
DLIB_ASSERT(dims.back() % 2 == 0, "last dimension needs to be even");
constexpr DFTI_CONFIG_VALUE dfti_type = std::is_same<T,float>::value ? DFTI_SINGLE : DFTI_DOUBLE;
DFTI_DESCRIPTOR_HANDLE h;
MKL_LONG status;
if (dims.num_dims() == 1)
{
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 1, dims[0]);
DLIB_DFTI_CHECK_STATUS(status);
}
else
{
const long lastdim = dims[1]/2+1;
MKL_LONG size[] = {dims[0], dims[1]};
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 2, size);
DLIB_DFTI_CHECK_STATUS(status);
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = lastdim;
strides[2] = 1;
status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = dims[1];
strides[2] = 1;
status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
}
const DFTI_CONFIG_VALUE inplacefft = (void*)in == (void*)out ? DFTI_INPLACE : DFTI_NOT_INPLACE;
status = DftiSetValue(h, DFTI_PLACEMENT, inplacefft);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiSetValue(h, DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
DLIB_DFTI_CHECK_STATUS(status);
// Unless we use sequential mode, the fft results are not correct.
status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiCommitDescriptor(h);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiComputeBackward(h, (void*)in, (void*)out);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiFreeDescriptor(&h);
DLIB_DFTI_CHECK_STATUS(status);
}
}
#endif // DLIB_MKL_FFT_H
...@@ -11,7 +11,12 @@ ...@@ -11,7 +11,12 @@
#include <dlib/compress_stream.h> #include <dlib/compress_stream.h>
#include <dlib/base64.h> #include <dlib/base64.h>
#ifdef DLIB_USE_MKL_FFT
#include <dlib/matrix/kiss_fft.h>
#include <dlib/matrix/mkl_fft.h>
#endif
#include "tester.h" #include "tester.h"
#include "fftr_good_data.h"
namespace namespace
{ {
...@@ -21,19 +26,35 @@ namespace ...@@ -21,19 +26,35 @@ namespace
using namespace std; using namespace std;
logger dlog("test.fft"); logger dlog("test.fft");
static dlib::rand rnd(10000);
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
matrix<complex<double> > rand_complex(long nr, long nc) template<typename R>
matrix<complex<R> > rand_complex(long nr, long nc, R scale = 10.0)
{
matrix<complex<R> > m(nr,nc);
for (long r = 0; r < m.nr(); ++r)
{
for (long c = 0; c < m.nc(); ++c)
{
m(r,c) = std::complex<R>(rnd.get_random_gaussian() * scale, rnd.get_random_gaussian() * scale);
}
}
return m;
}
template<typename R>
matrix<R> rand_real(long nr, long nc)
{ {
static dlib::rand rnd; matrix<R> m(nr,nc);
matrix<complex<double> > m(nr,nc);
for (long r = 0; r < m.nr(); ++r) for (long r = 0; r < m.nr(); ++r)
{ {
for (long c = 0; c < m.nc(); ++c) for (long c = 0; c < m.nc(); ++c)
{ {
m(r,c) = complex<double>(rnd.get_random_gaussian()*10, rnd.get_random_gaussian()*10); m(r,c) = rnd.get_random_gaussian() * 10.0;
} }
} }
return m; return m;
...@@ -65,35 +86,78 @@ namespace ...@@ -65,35 +86,78 @@ namespace
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void test_against_saved_good_fftrs()
{
std::stringstream base64_in, decompressed_in, decompressed_out;
dlib::base64 base64_coder;
dlib::compress_stream::kernel_1ea compressor;
base64_in = get_fftr_stringstream();
base64_coder.decode(base64_in, decompressed_in);
compressor.decompress(decompressed_in, decompressed_out);
matrix<double> m1;
matrix<complex<double>> m2;
while (decompressed_out.peek() != EOF)
{
print_spinner();
deserialize(m1,decompressed_out);
deserialize(m2,decompressed_out);
DLIB_TEST(max(norm(fftr(m1)-m2)) < 1e-16);
DLIB_TEST(max(squared(m1-ifftr(m2))) < 1e-16);
}
}
// ----------------------------------------------------------------------------------------
void test_random_ffts() void test_random_ffts()
{ {
for (int iter = 0; iter < 10; ++iter) int test = 0;
for (int nr = 1; nr <= 64; nr++)
{ {
print_spinner(); for (int nc = 1; nc <= 64; nc++)
for (int nr = 1; nr <= 128; nr*=2)
{ {
for (int nc = 1; nc <= 128; nc *= 2) if (++test % 100 == 0)
{ print_spinner();
const matrix<complex<double> > m1 = rand_complex(nr,nc);
const matrix<complex<float> > fm1 = matrix_cast<complex<float> >(rand_complex(nr,nc)); const matrix<complex<double> > m1 = rand_complex<double>(nr,nc);
const matrix<complex<float> > fm1 = rand_complex<float>(nr,nc);
DLIB_TEST(max(norm(ifft(fft(m1))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(fm1))-fm1)) < 1e-7); DLIB_TEST(max(norm(ifft(fft(m1))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(fm1))-fm1)) < 1e-7);
matrix<complex<double> > temp = m1;
matrix<complex<float> > ftemp = fm1; matrix<complex<double> > temp = m1;
fft_inplace(temp); matrix<complex<float> > ftemp = fm1;
fft_inplace(ftemp); fft_inplace(temp);
DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16); fft_inplace(ftemp);
DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7); DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16);
ifft_inplace(temp); DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7);
ifft_inplace(ftemp); ifft_inplace(temp);
DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16); ifft_inplace(ftemp);
DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7); DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16);
} DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7);
} }
} }
{
// test size 0 matrices.
matrix<complex<double>> temp;
matrix<complex<float>> ftemp;
fft_inplace(temp);
fft_inplace(ftemp);
DLIB_TEST(temp.size() == 0);
DLIB_TEST(ftemp.size() == 0);
DLIB_TEST(fft(temp).size() == 0);
DLIB_TEST(ifft(temp).size() == 0);
matrix<double> rtemp;
DLIB_TEST(fftr(rtemp).size() == 0);
DLIB_TEST(ifftr(temp).size() == 0);
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -102,8 +166,8 @@ namespace ...@@ -102,8 +166,8 @@ namespace
void test_real_compile_time_sized_ffts() void test_real_compile_time_sized_ffts()
{ {
print_spinner(); print_spinner();
const matrix<complex<double>,nr,nc> m1 = complex_matrix(real(rand_complex(nr,nc))); const matrix<complex<double>,nr,nc> m1 = complex_matrix(rand_real<double>(nr,nc));
const matrix<complex<float>,nr,nc> fm1 = matrix_cast<complex<float> >(complex_matrix(real(rand_complex(nr,nc)))); const matrix<complex<float>,nr,nc> fm1 = complex_matrix(rand_real<float>(nr,nc));
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16); DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7); DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7);
...@@ -122,40 +186,302 @@ namespace ...@@ -122,40 +186,302 @@ namespace
void test_random_real_ffts() void test_random_real_ffts()
{ {
for (int iter = 0; iter < 10; ++iter) int test = 0;
for (int nr = 1; nr <= 64; nr++)
{ {
print_spinner(); for (int nc = 1; nc <= 64; nc++)
for (int nr = 1; nr <= 128; nr*=2)
{ {
for (int nc = 1; nc <= 128; nc *= 2) if (++test % 100 == 0)
{ print_spinner();
const matrix<complex<double> > m1 = complex_matrix(real(rand_complex(nr,nc)));
const matrix<complex<float> > fm1 = matrix_cast<complex<float> >(complex_matrix(real(rand_complex(nr,nc)))); const matrix<complex<double> > m1 = complex_matrix(rand_real<double>(nr,nc));
const matrix<complex<float> > fm1 = complex_matrix(rand_real<float>(nr,nc));
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7); DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7);
matrix<complex<double> > temp = m1;
matrix<complex<float> > ftemp = fm1; matrix<complex<double> > temp = m1;
fft_inplace(temp); matrix<complex<float> > ftemp = fm1;
fft_inplace(ftemp); fft_inplace(temp);
DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16); fft_inplace(ftemp);
DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7); DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16);
ifft_inplace(temp); DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7);
ifft_inplace(ftemp); ifft_inplace(temp);
DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16); ifft_inplace(ftemp);
DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7); DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16);
} DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7);
} }
} }
test_real_compile_time_sized_ffts<16,16>(); test_real_compile_time_sized_ffts<16,16>();
test_real_compile_time_sized_ffts<16,1>(); test_real_compile_time_sized_ffts<16,1>();
test_real_compile_time_sized_ffts<1,16>(); test_real_compile_time_sized_ffts<1,16>();
test_real_compile_time_sized_ffts<480,480>(); //2^5 * 3 * 5
test_real_compile_time_sized_ffts<480,1>(); //2^5 * 3 * 5
test_real_compile_time_sized_ffts<1,480>(); //2^5 * 3 * 5
test_real_compile_time_sized_ffts<131,131>(); //some large prime
test_real_compile_time_sized_ffts<131,1>(); //some large prime
test_real_compile_time_sized_ffts<1,131>(); //some large prime
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template<typename R>
void test_linearity_complex()
{
static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 5e-2;
static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
int test = 0;
auto func = [&](long nr, long nc)
{
if (++test % 100 == 0)
print_spinner();
const matrix<complex<R>> m1 = rand_complex<R>(nr,nc);
const matrix<complex<R>> m2 = rand_complex<R>(nr,nc);
const R a1 = rnd.get_double_in_range(-10.0, 10.0);
const R a2 = rnd.get_double_in_range(-10.0, 10.0);
const matrix<complex<R>> m3 = a1*m1 + a2*m2;
const matrix<complex<R>> f1 = fft(m1);
const matrix<complex<R>> f2 = fft(m2);
const matrix<complex<R>> f3 = fft(m3);
R diff = max(norm(f3 - a1*f1 - a2*f2));
DLIB_TEST_MSG(diff < tol, "diff " << diff << " not within tol " << tol << " where (nr,nc) = (" << nr << "," << nc << ")" << " type " << typelabel);
const matrix<complex<R>> m4 = ifft(f3);
diff = max(norm(m4 - m3));
DLIB_TEST_MSG(diff < tol, "diff " << diff << " not within tol " << tol << " where (nr,nc) = (" << nr << "," << nc << ")" << " type " << typelabel);
};
for (int nr = 1; nr <= 64; nr++)
{
for (int nc = 1; nc <= 64; nc++)
{
func(nr,nc);
}
}
//some odd balls...
func(3, 131); print_spinner();
func(123, 103); print_spinner();
}
// ----------------------------------------------------------------------------------------
template<typename R>
void test_linearity_real()
{
    // Linearity check for the real-input transform: for random real matrices
    // m1, m2 and random scalars a1, a2 in [-10, 10], fftr(a1*m1 + a2*m2) has to
    // agree with a1*fftr(m1) + a2*fftr(m2), and ifftr has to bring the combined
    // spectrum back to the combined input. Output shapes are validated against
    // fftr_nc_size() / ifftr_nc_size().
    static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
    static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 1e-3;

    int test = 0;

    auto func = [&](long nr, long nc)
    {
        // Tick the spinner once every 100 cases.
        ++test;
        if (test % 100 == 0)
            print_spinner();

        // Random operands and weights; the draw order is part of the test's
        // random sequence, so it is kept as m1, m2, a1, a2.
        const matrix<R> m1 = rand_real<R>(nr,nc);
        const matrix<R> m2 = rand_real<R>(nr,nc);
        const R a1 = rnd.get_double_in_range(-10.0, 10.0);
        const R a2 = rnd.get_double_in_range(-10.0, 10.0);
        const matrix<R> m3 = a1*m1 + a2*m2;

        // Forward transforms of both operands and of their combination.
        const matrix<complex<R>> f1 = fftr(m1);
        const matrix<complex<R>> f2 = fftr(m2);
        const matrix<complex<R>> f3 = fftr(m3);

        // Shape of the forward output.
        DLIB_TEST(f1.nr() == m1.nr());
        DLIB_TEST(f1.nc() == fftr_nc_size(m1.nc()));

        // Largest squared magnitude of the linearity residual.
        R diff = max(norm(f3 - a1*f1 - a2*f2));
        DLIB_TEST_MSG(diff < tol, "diff " << diff << " not within tol " << tol << " where (nr,nc) = (" << nr << "," << nc << ")" << " type " << typelabel);

        // Inverse transform: shape first, then round-trip error.
        const matrix<R> m4 = ifftr(f3);
        DLIB_TEST(m4.nr() == f3.nr());
        DLIB_TEST(m4.nc() == ifftr_nc_size(f3.nc()));

        diff = max(squared(m4 - m3));
        DLIB_TEST_MSG(diff < tol, "diff " << diff << " not within tol " << tol << " where (nr,nc) = (" << nr << "," << nc << ")" << " type " << typelabel);
    };

    // Sweep every even (rows, cols) pair from 2x2 up to 64x64.
    // NOTE(review): only even sizes are swept, presumably because the real
    // transform needs an even number of columns -- confirm.
    for (int rows = 2; rows <= 64; rows += 2)
        for (int cols = 2; cols <= 64; cols += 2)
            func(rows, cols);

    // A couple of less regular shapes.
    func(89, 102);
    print_spinner();
    func(123, 48);
    print_spinner();
}
// ----------------------------------------------------------------------------------------
template<typename R>
void test_kronecker_delta_impulse_response()
{
    // The transform of a unit impulse placed at element (0,0) must be 1 in
    // every bin: real part equal to one and imaginary part equal to zero,
    // each within tol (measured as the largest squared deviation).
    static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 1e-3;

    int test = 0;

    auto func = [&](long nr, long nc)
    {
        // Tick the spinner once every 100 cases.
        ++test;
        if (test % 100 == 0)
            print_spinner();

        // Reference real part: an nr x nc matrix of ones.
        matrix<R> ones = dlib::ones_matrix<R>(nr,nc);

        // Input: all zeros except a single 1 at the origin.
        matrix<complex<R>> x = dlib::zeros_matrix<complex<R>>(nr,nc);
        x(0,0) = 1.0f;

        matrix<complex<R>> f = fft(x);

        R diff_real = max(squared(real(f) - ones));
        R diff_imag = max(squared(imag(f)));
        DLIB_TEST(diff_real < tol);
        DLIB_TEST(diff_imag < tol);
    };

    // Every shape from 1x1 up to 64x64.
    for (int rows = 1; rows <= 64; rows++)
        for (int cols = 1; cols <= 64; cols++)
            func(rows, cols);

    // A couple of less regular shapes.
    func(3, 131);
    print_spinner();
    func(123, 103);
    print_spinner();
}
// ----------------------------------------------------------------------------------------
template<typename R>
void test_time_shift()
{
    // Checks the circular-shift property of the DFT on single-row signals:
    // rotating the input left by time_shift samples must multiply bin i of
    // the spectrum by exp(j*2*pi*time_shift*i/size). Real and imaginary parts
    // are compared separately using the largest squared deviation.
    static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 1e-1;
    static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
    int test = 0;
    auto func = [&](size_t size, size_t time_shift)
    {
        // Tick the spinner once every 100 cases.
        if (++test % 100 == 0)
            print_spinner();

        // x2 is x1 rotated left by time_shift elements.
        matrix<complex<R>> x1 = rand_complex<R>(1,size, 1.0);
        matrix<complex<R>> x2 = x1;
        std::rotate(x2.begin(), x2.begin() + time_shift, x2.end());

        matrix<complex<R>> f1 = fft(x1);
        matrix<complex<R>> f2 = fft(x2);

        // Build the expected spectrum of x2 by applying the per-bin phase
        // factor to the spectrum of x1.
        matrix<complex<R>> f2_expected = f1;
        for (long i = 0 ; i < f1.size() ; i++)
            f2_expected(i) = f1(i)*std::polar<R>(R(1), R(6.283185307179586476925286766559005768394338798*time_shift*i / size));

        const auto diff_real = max(squared(real(f2) - real(f2_expected)));
        const auto diff_imag = max(squared(imag(f2) - imag(f2_expected)));
        DLIB_TEST_MSG(diff_real < tol, typelabel << " diff_real " << diff_real << " size " << size << " shift " << time_shift);
        // The label here used to read "diff_real", which made a failing
        // imaginary-part check look like a real-part failure.
        DLIB_TEST_MSG(diff_imag < tol, typelabel << " diff_imag " << diff_imag << " size " << size << " shift " << time_shift);
    };

    // Sizes 16..63, with shifts 10, 20, ... up to half the size.
    for (size_t size = 16 ; size < 64 ; size++)
    {
        for (size_t time_shift = 10 ; time_shift < size/2 + 1 ; time_shift += 10)
        {
            func(size, time_shift);
        }
    }

    //some odd balls...
    func(3,1);
    func(123,16);
    func(123,122);
}
// ----------------------------------------------------------------------------------------
template<typename R>
void test_fftr_conjugacy_1D()
{
static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 1e-6;
static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
auto func = [&](long nc)
{
print_spinner();
matrix<R> m1 = rand_real<R>(1, nc);
matrix<complex<R>> f1 = fftr(m1);
matrix<complex<R>> f2 = fft(complex_matrix(m1));
matrix<complex<R>> f3 = join_rows(f1, conj(fliplr(colm(f1,range(1,f1.nc()-2)))));
const R diff = max(norm(f2-f3));
DLIB_TEST_MSG(diff < tol, typelabel << " diff " << diff << " nr " << m1.nr() << " nc " << m1.nc() << " tol " << tol);
};
//don't start from 2, as that is a special case where fft and fftr
//give the same number of columns.
//Therefore, fiplr(colm(f1,range(1,f1.nc()-2))) wouldn't work
for (long nc = 4 ; nc <= 128 ; nc+=2)
{
func(nc);
}
//some odd balls...
func(480);
func(130);
}
// ----------------------------------------------------------------------------------------
#ifdef DLIB_USE_MKL_FFT
template<typename R>
void test_kiss_vs_mkl()
{
    // Cross-checks the kiss and MKL backends: the same random real input is
    // pushed through kiss_fftr and mkl_fftr, then each spectrum goes back
    // through the matching inverse, and both stages must agree within tol.
    //
    // All buffers are typed on R. They used to be hard-coded to float, so the
    // <double> instantiation silently repeated the float comparison while
    // reporting itself as "double".
    // NOTE(review): assumes the kiss_*/mkl_* wrappers accept double buffers as
    // well as float ones -- confirm against their declarations.
    static constexpr double tol = 1e-2; // same tolerance for float and double
    static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
    int test = 0;

    // Every even (nr, nc) pair from 2x2 up to 64x64.
    for (int nr = 2; nr <= 64; nr += 2)
    {
        for (int nc = 2; nc <= 64; nc += 2)
        {
            // Tick the spinner once every 100 cases.
            if (++test % 100 == 0)
                print_spinner();

            // x1: real input; y1/y2: round-trip outputs of each backend.
            std::vector<R> x1(nr*nc), y1(nr*nc), y2(nr*nc);
            // Real-transform spectra hold nc/2+1 columns per row.
            std::vector<std::complex<R>> f1(nr*(nc/2+1)), f2(nr*(nc/2+1));

            for (int i = 0 ; i < (nr*nc) ; i++)
                x1[i] = rnd.get_random_gaussian();

            // Forward transforms must match.
            kiss_fftr({nr,nc}, &x1[0], &f1[0]);
            mkl_fftr({nr,nc}, &x1[0], &f2[0]);
            const R diff1 = max(norm(mat(f1) - mat(f2)));
            DLIB_TEST_MSG(diff1 < tol, typelabel << " diff1 " << diff1 << " nr " << nr << " nc " << nc);

            // Inverse transforms must match as well.
            kiss_ifftr({nr,nc}, &f1[0], &y1[0]);
            mkl_ifftr({nr,nc}, &f2[0], &y2[0]);
            const R diff2 = max(squared(mat(y1) - mat(y2)));
            DLIB_TEST_MSG(diff2 < tol, typelabel << " diff2 " << diff2 << " nr " << nr << " nc " << nc);
        }
    }
}
#endif
class test_fft : public tester class test_fft : public tester
{ {
public: public:
...@@ -169,8 +495,23 @@ namespace ...@@ -169,8 +495,23 @@ namespace
) )
{ {
test_against_saved_good_ffts(); test_against_saved_good_ffts();
test_against_saved_good_fftrs();
test_random_ffts(); test_random_ffts();
test_random_real_ffts(); test_random_real_ffts();
test_linearity_real<float>();
test_linearity_real<double>();
test_linearity_complex<float>();
test_linearity_complex<double>();
test_kronecker_delta_impulse_response<float>();
test_kronecker_delta_impulse_response<double>();
test_time_shift<float>();
test_time_shift<double>();
test_fftr_conjugacy_1D<float>();
test_fftr_conjugacy_1D<double>();
#ifdef DLIB_USE_MKL_FFT
test_kiss_vs_mkl<float>();
test_kiss_vs_mkl<double>();
#endif
} }
} a; } a;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment