Unverified Commit b8807922 authored by pfeatherstone's avatar pfeatherstone Committed by GitHub
Browse files

FFT: added std::vector overloads for fft, ifft, fft_inplace and ifft_inplace (#2286)



* [FFT] added fft, ifft, fft_inplace and ifft_inplace overloads for std::vector

* [FFT] 	- static_assert T is a floating point type. There are static asserts in mkl_fft and kiss_fft, but it doesn't hurt adding them in the matrix API too so users get helpful warnings higher up in the API.

* [FFT] 	- added documentation for std::vector overloads in matrix_fft_abstract.h file
Co-authored-by: default avatarpf <pf@pf-ubuntu-dev>
parent 044ff91b
...@@ -37,6 +37,25 @@ namespace dlib ...@@ -37,6 +37,25 @@ namespace dlib
return nc == 0 ? 0 : 2*(nc-1); return nc == 0 ? 0 : 2*(nc-1);
} }
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
matrix<std::complex<T>,0,1> fft (const std::vector<std::complex<T>, Alloc>& in)
{
//complex FFT
static_assert(std::is_floating_point<T>::value, "only support floating point types");
matrix<std::complex<T>,0,1> out(in.size());
if (in.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({(long)in.size()}, &in[0], &out(0,0), false);
#else
kiss_fft({(long)in.size()}, &in[0], &out(0,0), false);
#endif
}
return out;
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L > template < typename T, long NR, long NC, typename MM, typename L >
...@@ -67,6 +86,26 @@ namespace dlib ...@@ -67,6 +86,26 @@ namespace dlib
return fft(in); return fft(in);
} }
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
matrix<std::complex<T>,0,1> ifft (const std::vector<std::complex<T>, Alloc>& in)
{
//complex FFT
static_assert(std::is_floating_point<T>::value, "only support floating point types");
matrix<std::complex<T>,0,1> out(in.size());
if (in.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({(long)in.size()}, &in[0], &out(0,0), true);
#else
kiss_fft({(long)in.size()}, &in[0], &out(0,0), true);
#endif
out /= out.size();
}
return out;
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L > template < typename T, long NR, long NC, typename MM, typename L >
...@@ -159,12 +198,29 @@ namespace dlib ...@@ -159,12 +198,29 @@ namespace dlib
matrix<typename EXP::type> in(data); matrix<typename EXP::type> in(data);
return ifftr(in); return ifftr(in);
} }
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
void fft_inplace (std::vector<std::complex<T>, Alloc>& data)
{
static_assert(std::is_floating_point<T>::value, "only support floating point types");
if (data.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({(long)data.size()}, &data[0], &data[0], false);
#else
kiss_fft({(long)data.size()}, &data[0], &data[0], false);
#endif
}
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L > template < typename T, long NR, long NC, typename MM, typename L >
void fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data) void fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
{ {
static_assert(std::is_floating_point<T>::value, "only support floating point types");
if (data.size() != 0) if (data.size() != 0)
{ {
#ifdef DLIB_USE_MKL_FFT #ifdef DLIB_USE_MKL_FFT
...@@ -175,11 +231,28 @@ namespace dlib ...@@ -175,11 +231,28 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
void ifft_inplace (std::vector<std::complex<T>, Alloc>& data)
{
static_assert(std::is_floating_point<T>::value, "only support floating point types");
if (data.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({(long)data.size()}, &data[0], &data[0], true);
#else
kiss_fft({(long)data.size()}, &data[0], &data[0], true);
#endif
}
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L > template < typename T, long NR, long NC, typename MM, typename L >
void ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data) void ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
{ {
static_assert(std::is_floating_point<T>::value, "only support floating point types");
if (data.size() != 0) if (data.size() != 0)
{ {
#ifdef DLIB_USE_MKL_FFT #ifdef DLIB_USE_MKL_FFT
......
...@@ -60,6 +60,26 @@ namespace dlib ...@@ -60,6 +60,26 @@ namespace dlib
- ifft(D) == data - ifft(D) == data
!*/ !*/
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
matrix<std::complex<T>,0,1> fft (
const std::vector<std::complex<T>, Alloc>& data
);
/*!
requires
- data contains elements of type std::complex<> that itself contains double, float, or long double.
ensures
- Computes the 1 dimensional discrete Fourier transform of the given data
vector and returns it. In particular, we return a matrix D such that:
- D.nr() == data.size()
- D.nc() == 1
- D(0,0) == the DC term of the Fourier transform.
- starting with D(0,0), D contains progressively higher frequency components
of the input data.
- ifft(D) == data
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename EXP> template <typename EXP>
...@@ -78,6 +98,24 @@ namespace dlib ...@@ -78,6 +98,24 @@ namespace dlib
- fft(D) == data - fft(D) == data
!*/ !*/
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
matrix<std::complex<T>,0,1> ifft (
const std::vector<std::complex<T>, Alloc>& data
)
/*!
requires
- data contains elements of type std::complex<> that itself contains double, float, or long double.
ensures
- Computes the 1 dimensional inverse discrete Fourier transform of the
given data vector and returns it. In particular, we return a matrix D such
that:
- D.nr() == data.size()
- D.nc() == 1
- fft(D) == data
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename EXP> template <typename EXP>
...@@ -138,6 +176,21 @@ namespace dlib ...@@ -138,6 +176,21 @@ namespace dlib
- #data == fft(data) - #data == fft(data)
!*/ !*/
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
void fft_inplace (
std::vector<std::complex<T>, Alloc>& data
)
/*!
requires
- data contains elements of type std::complex<> that itself contains double, float, or long double.
ensures
- This function is identical to fft() except that it does the FFT in-place.
That is, after this function executes we will have:
- #data == fft(data)
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -163,6 +216,22 @@ namespace dlib ...@@ -163,6 +216,22 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
void ifft_inplace (
std::vector<std::complex<T>, Alloc>& data
);
/*!
requires
- data contains elements of type std::complex<> that itself contains double, float, or long double.
ensures
- This function is identical to ifft() except that it does the inverse FFT
in-place. That is, after this function executes we will have:
- #data == ifft(data)*data.size()
- Note that the output needs to be divided by data.size() to complete the
inverse transformation.
!*/
// ----------------------------------------------------------------------------------------
} }
#endif // DLIB_FFt_ABSTRACT_Hh_ #endif // DLIB_FFt_ABSTRACT_Hh_
......
...@@ -481,6 +481,76 @@ namespace ...@@ -481,6 +481,76 @@ namespace
} }
#endif #endif
template<typename R>
void test_vector_overload_outplace()
{
static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 5e-2;
static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
int test = 0;
auto func = [&](long size)
{
if (++test % 100 == 0)
print_spinner();
const matrix<complex<R>,0,1> m1 = rand_complex<R>(size,1);
const matrix<complex<R>,0,1> f1 = fft(m1); //this fft uses the dlib::matrix overload
const std::vector<complex<R>> m1_v(m1.begin(), m1.end()); //target
const std::vector<complex<R>> f1_v(f1.begin(), f1.end()); //target
const matrix<complex<R>,0,1> f2 = fft(m1_v); //this fft uses the std::vector overload
const matrix<complex<R>,0,1> m2 = ifft(f1_v); //this ifft uses the std::vector overload
R diff = max(norm(f2 - f1));
DLIB_TEST_MSG(diff < tol, "diff " << diff << " not within tol " << tol << " where (size) = (" << size << ")" << " type " << typelabel);
diff = max(norm(m2 - m1));
DLIB_TEST_MSG(diff < tol, "diff " << diff << " not within tol " << tol << " where (size) = (" << size << ")" << " type " << typelabel);
};
for (long size = 1; size <= 64; size++)
func(size);
//some odd balls...
func(103); print_spinner();
func(123); print_spinner();
func(131); print_spinner();
}
template<typename R>
void test_vector_overload_inplace()
{
static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 5e-2;
static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
int test = 0;
auto func = [&](long size)
{
if (++test % 100 == 0)
print_spinner();
matrix<complex<R>,0,1> m1 = rand_complex<R>(size,1);
std::vector<complex<R>> m1_v(m1.begin(), m1.end());
fft_inplace(m1);
fft_inplace(m1_v);
R diff = max(norm(m1 - mat(m1_v)));
DLIB_TEST_MSG(diff < tol, "diff " << diff << " not within tol " << tol << " where (size) = (" << size << ")" << " type " << typelabel);
};
for (long size = 1; size <= 64; size++)
func(size);
//some odd balls...
func(103); print_spinner();
func(123); print_spinner();
func(131); print_spinner();
}
class test_fft : public tester class test_fft : public tester
{ {
public: public:
...@@ -511,6 +581,10 @@ namespace ...@@ -511,6 +581,10 @@ namespace
test_kiss_vs_mkl<float>(); test_kiss_vs_mkl<float>();
test_kiss_vs_mkl<double>(); test_kiss_vs_mkl<double>();
#endif #endif
test_vector_overload_outplace<float>();
test_vector_overload_outplace<double>();
test_vector_overload_inplace<float>();
test_vector_overload_inplace<double>();
} }
} a; } a;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment