Unverified Commit 39eec294 authored by pfeatherstone's avatar pfeatherstone Committed by GitHub
Browse files

Arbitrary sized FFTs using modified kissFFT as default backend and MKL otherwise (#2253)



* [FFT] added kissfft wrappers, moved kiss and mkl wrappers into separate files, call the right functions in matrix_fft.h
Co-authored-by: default avatarpf <pf@pf-ubuntu-dev>
Co-authored-by: default avatarDavis King <davis@dlib.net>
parent 71dd9a6c
...@@ -86,6 +86,7 @@ if (UNIX OR MINGW) ...@@ -86,6 +86,7 @@ if (UNIX OR MINGW)
if (SIZE_OF_VOID_PTR EQUAL 8) if (SIZE_OF_VOID_PTR EQUAL 8)
set( mkl_search_path set( mkl_search_path
/opt/intel/oneapi/mkl/latest/lib/intel64
/opt/intel/mkl/*/lib/em64t /opt/intel/mkl/*/lib/em64t
/opt/intel/mkl/lib/intel64 /opt/intel/mkl/lib/intel64
/opt/intel/lib/intel64 /opt/intel/lib/intel64
...@@ -99,6 +100,7 @@ if (UNIX OR MINGW) ...@@ -99,6 +100,7 @@ if (UNIX OR MINGW)
mark_as_advanced(mkl_intel) mark_as_advanced(mkl_intel)
else() else()
set( mkl_search_path set( mkl_search_path
/opt/intel/oneapi/mkl/latest/lib/ia32
/opt/intel/mkl/*/lib/32 /opt/intel/mkl/*/lib/32
/opt/intel/mkl/lib/ia32 /opt/intel/mkl/lib/ia32
/opt/intel/lib/ia32 /opt/intel/lib/ia32
...@@ -114,6 +116,7 @@ if (UNIX OR MINGW) ...@@ -114,6 +116,7 @@ if (UNIX OR MINGW)
# Get mkl_include_dir # Get mkl_include_dir
set(mkl_include_search_path set(mkl_include_search_path
/opt/intel/oneapi/mkl/latest/include
/opt/intel/mkl/include /opt/intel/mkl/include
/opt/intel/include /opt/intel/include
) )
......
#ifndef DLIB_FFT_SIZE_H
#define DLIB_FFT_SIZE_H
#include <array>
#include <algorithm>
#include <numeric>
#include "../assert.h"
#include "../hash.h"
namespace dlib
{
class fft_size
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is a container used to store the dimensions of an FFT
operation. It is implemented as a stack-based container with an
upper bound of 5 dimensions (batch,channels,height,width,depth).
All dimensions must be strictly positive.
The object is either default constructed, constructed with an
initialiser list or with a pair of iterators
If default-constructed, the object is empty and in an invalid state.
That is, FFT functions will throw if attempted to be used with such
an object.
If constructed with an initialiser list L, the object is properly
initialised provided:
- L.size() > 0 and L.size() <= 5
- L contains strictly positive values
If constructed with a pair of iterators, the behaviour of the
constructor is exactly the same as if constructed with an
initializer list spanned by those iterators.
Once the object is constructed, it is immutable.
!*/
public:
using container_type = std::array<long,5>;
using const_reference = container_type::const_reference;
using iterator = container_type::iterator;
using const_iterator = container_type::const_iterator;
fft_size() = default;
/*!
ensures
- *this is properly initialised
- num_dims() == 0
!*/
template<typename ConstIterator>
fft_size(ConstIterator dims_begin, ConstIterator dims_end)
/*!
requires
- ConstIterator is const iterator type that points to a long object
- std::distance(dims_begin, dims_end) > 0
- std::distance(dims_begin, dims_end) <= 5
- range contains strictly positive values
ensures
- *this is properly initialised
- num_dims() == std::distance(dims_begin, dims_end)
- num_elements() == product of all values in range
!*/
{
const size_t ndims = std::distance(dims_begin, dims_end);
DLIB_ASSERT(ndims > 0, "fft_size objects must be non-empty");
DLIB_ASSERT(ndims <= _dims.size(), "fft_size objects must have size less than 6");
DLIB_ASSERT(std::find_if(dims_begin, dims_end, [](long dim) {return dim <= 0;}) == dims_end, "fft_size objects must contain strictly positive values");
std::copy(dims_begin, dims_end, _dims.begin());
_size = ndims;
_num_elements = std::accumulate(dims_begin, dims_end, 1, std::multiplies<long>());
}
fft_size(std::initializer_list<long> dims)
: fft_size(dims.begin(), dims.end())
/*!
requires
- dims.size() > 0 and dims.size() <= 5
- dims contains strictly positive values
ensures
- *this is properly initialised
- num_dims() == dims.size()
- num_elements() == product of all values in dims
!*/
{
}
size_t num_dims() const
/*!
ensures
- returns the number of dimensions
!*/
{
return _size;
}
long num_elements() const
/*!
ensures
- if num_dims() > 0, returns the product of all dimensions, i.e. the total number
of elements
- if num_dims() == 0, returns 0
!*/
{
return _num_elements;
}
const_reference operator[](size_t index) const
/*!
requires
- index < num_dims()
ensures
- returns a const reference to the dimension at position index
!*/
{
DLIB_ASSERT(index < _size, "index " << index << " out of range [0," << _size << ")");
return _dims[index];
}
const_reference back() const
/*!
requires
- num_dims() > 0
ensures
- returns a const reference to (*this)[num_dims()-1]
!*/
{
DLIB_ASSERT(_size > 0, "object is empty");
return _dims[_size-1];
}
const_iterator begin() const
/*!
ensures
- returns a const iterator that points to the first dimension
in this container or end() if the array is empty.
!*/
{
return _dims.begin();
}
const_iterator end() const
/*!
ensures
- returns a const iterator that points to one past the end of
the container.
!*/
{
return _dims.begin() + _size;
}
bool operator==(const fft_size& other) const
/*!
ensures
- returns true if two fft_size objects have same size and same dimensions, i.e. if they have identical states
!*/
{
return this->_size == other._size && std::equal(begin(), end(), other.begin());
}
private:
size_t _size = 0;
size_t _num_elements = 0;
container_type _dims;
};
inline dlib::uint32 hash(
const fft_size& item,
dlib::uint32 seed = 0)
{
seed = dlib::hash((dlib::uint64)item.num_dims(), seed);
seed = std::accumulate(item.begin(), item.end(), seed, [](dlib::uint32 seed, long next) {
return dlib::hash((dlib::uint64)next, seed);
});
return seed;
}
/*!
ensures
- returns a 32bit hash of the data stored in item.
!*/
inline fft_size pop_back(const fft_size& size)
{
DLIB_ASSERT(size.num_dims() > 0);
return fft_size(size.begin(), size.end() - 1);
}
/*!
requires
- num_dims.size() > 0
ensures
- returns a copy of size with the last dimension removed.
!*/
inline fft_size squeeze_ones(const fft_size size)
{
DLIB_ASSERT(size.num_dims() > 0);
fft_size newsize;
if (size.num_elements() == 1)
{
newsize = {1};
}
else
{
fft_size::container_type tmp;
auto end = std::copy_if(size.begin(), size.end(), tmp.begin(), [](long dim){return dim != 1;});
newsize = fft_size(tmp.begin(), end);
}
return newsize;
}
/*!
requires
- num_dims.size() > 0
ensures
- removes dimensions with values equal to 1, yielding a new fft_size object with the same num_elements() but fewer dimensions
!*/
}
#endif //DLIB_FFT_SIZE_H
/*
* Copyright (c) 2003-2010, Mark Borgerding. All rights reserved.
* This file is part of KISS FFT - https://github.com/mborgerding/kissfft
*
* SPDX-License-Identifier: BSD-3-Clause
* See COPYING file for more information.
*/
#ifndef DLIB_KISS_FFT_H
#define DLIB_KISS_FFT_H
#include <complex>
#include <vector>
#include <cmath>
#include <stdexcept>
#include <algorithm>
#include <unordered_map>
#include <mutex>
#include <numeric>
#include "fft_size.h"
#include "../hash.h"
#include "../assert.h"
#define C_FIXDIV(x,y) /*noop*/
namespace dlib
{
namespace kiss_details
{
struct plan_key
{
fft_size dims;
bool is_inverse;
plan_key(const fft_size& dims_, bool is_inverse_)
: dims(dims_), is_inverse(is_inverse_)
{}
bool operator==(const plan_key& other) const
{
return std::tie(dims, is_inverse) == std::tie(other.dims, other.is_inverse);
}
uint32 hash() const
{
using dlib::hash;
uint32 ret = 0;
ret = hash(dims, ret);
ret = hash((uint32)is_inverse, ret);
return ret;
}
};
template<typename T>
struct kiss_fft_state
{
long nfft;
bool inverse;
std::vector<int> factors;
std::vector<std::complex<T>> twiddles;
kiss_fft_state() = default;
kiss_fft_state(const plan_key& key);
};
template<typename T>
struct kiss_fftnd_state
{
fft_size dims;
std::vector<kiss_fft_state<T>> plans;
kiss_fftnd_state() = default;
kiss_fftnd_state(const plan_key& key);
};
template<typename T>
struct kiss_fftr_state
{
kiss_fft_state<T> substate;
std::vector<std::complex<T>> super_twiddles;
kiss_fftr_state() = default;
kiss_fftr_state(const plan_key& key);
};
template<typename T>
struct kiss_fftndr_state
{
kiss_fftr_state<T> cfg_r;
kiss_fftnd_state<T> cfg_nd;
kiss_fftndr_state() = default;
kiss_fftndr_state(const plan_key& key);
};
template<typename T>
inline void kf_bfly2(
std::complex<T> * Fout,
const size_t fstride,
const kiss_fft_state<T>& cfg,
const int m
)
{
const std::complex<T> * tw1 = &cfg.twiddles[0];
std::complex<T> t;
std::complex<T> * Fout2 = Fout + m;
for (int i = 0 ; i < m ; i++)
{
t = Fout2[i] * tw1[i*fstride];
Fout2[i] = Fout[i] - t;
Fout[i] += t;
}
}
template<typename T>
inline std::complex<T> rot_PI_2(std::complex<T> z)
{
return std::complex<T>(z.imag(), -z.real());
}
template<typename T>
inline void kf_bfly3 (
std::complex<T> * Fout,
const size_t fstride,
const kiss_fft_state<T>& cfg,
const size_t m
)
{
const size_t m2 = 2*m;
const std::complex<T> *tw1,*tw2;
std::complex<T> scratch[5];
const std::complex<T> epi3 = cfg.twiddles[fstride*m];
tw1=tw2=&cfg.twiddles[0];
constexpr T half = 0.5;
for (size_t k = 0 ; k < m ; k++)
{
C_FIXDIV(Fout[k],3); C_FIXDIV(Fout[k+m],3); C_FIXDIV(Fout[m2+k],3); //noop for float and double
scratch[1] = Fout[k+m] * tw1[k*fstride];
scratch[2] = Fout[k+m2] * tw2[k*fstride*2];
scratch[3] = scratch[1] + scratch[2];
scratch[0] = scratch[1] - scratch[2];
Fout[m+k] = Fout[k] - half * scratch[3];
scratch[0] *= epi3.imag();
Fout[k] += scratch[3];
Fout[k+m2] = Fout[k+m] + rot_PI_2(scratch[0]);
Fout[k+m] -= rot_PI_2(scratch[0]);
}
}
template<typename T>
inline void kf_bfly4(
std::complex<T> * Fout,
const size_t fstride,
const kiss_fft_state<T>& cfg,
const size_t m
)
{
const std::complex<T> *tw1,*tw2,*tw3;
std::complex<T> scratch[6];
const size_t m2=2*m;
const size_t m3=3*m;
tw3 = tw2 = tw1 = &cfg.twiddles[0];
for (size_t k = 0 ; k < m ; k++)
{
C_FIXDIV(Fout[k],4); C_FIXDIV(Fout[m],4); C_FIXDIV(Fout[m2+k],4); C_FIXDIV(Fout[m3+k],4);
scratch[0] = Fout[m+k] * tw1[k*fstride];
scratch[1] = Fout[m2+k] * tw2[k*fstride*2];
scratch[2] = Fout[m3+k] * tw3[k*fstride*3];
scratch[5] = Fout[k] - scratch[1];
Fout[k] += scratch[1];
scratch[3] = scratch[0] + scratch[2];
scratch[4] = scratch[0] - scratch[2];
Fout[m2+k] = Fout[k] - scratch[3];
Fout[k] += scratch[3];
if(cfg.inverse) {
Fout[m+k] = scratch[5] - rot_PI_2(scratch[4]);
Fout[m3+k] = scratch[5] + rot_PI_2(scratch[4]);
}else {
Fout[m+k] = scratch[5] + rot_PI_2(scratch[4]);
Fout[m3+k] = scratch[5] - rot_PI_2(scratch[4]);
}
}
}
template<typename T>
inline void kf_bfly5(
std::complex<T> * Fout,
const size_t fstride,
const kiss_fft_state<T>& cfg,
const int m
)
{
std::complex<T> scratch[13];
const std::complex<T> * twiddles = &cfg.twiddles[0];
const std::complex<T> ya = twiddles[fstride*m];
const std::complex<T> yb = twiddles[fstride*2*m];
std::complex<T> *Fout0=Fout;
std::complex<T> *Fout1=Fout0+m;
std::complex<T> *Fout2=Fout0+2*m;
std::complex<T> *Fout3=Fout0+3*m;
std::complex<T> *Fout4=Fout0+4*m;
const std::complex<T> *tw = &cfg.twiddles[0];
for (int u=0; u<m; ++u )
{
scratch[0] = Fout0[u];
scratch[1] = Fout1[u] * tw[u*fstride]; //C_MUL(scratch[1] ,*Fout1, tw[u*fstride]);
scratch[2] = Fout2[u] * tw[2*u*fstride]; //C_MUL(scratch[2] ,*Fout2, tw[2*u*fstride]);
scratch[3] = Fout3[u] * tw[3*u*fstride]; //C_MUL(scratch[3] ,*Fout3, tw[3*u*fstride]);
scratch[4] = Fout4[u] * tw[4*u*fstride]; //C_MUL(scratch[4] ,*Fout4, tw[4*u*fstride]);
scratch[7] = scratch[1] + scratch[4]; //C_ADD( scratch[7],scratch[1],scratch[4]);
scratch[10] = scratch[1] - scratch[4]; //C_SUB( scratch[10],scratch[1],scratch[4]);
scratch[8] = scratch[2] + scratch[3]; //C_ADD( scratch[8],scratch[2],scratch[3]);
scratch[9] = scratch[2] - scratch[3]; //C_SUB( scratch[9],scratch[2],scratch[3]);
Fout0[u] += scratch[7] + scratch[8];
scratch[5].real(scratch[0].real() + scratch[7].real() * ya.real() + scratch[8].real() * yb.real());
scratch[5].imag(scratch[0].imag() + scratch[7].imag() * ya.real() + scratch[8].imag() * yb.real());
scratch[6].real(scratch[10].imag() * ya.imag() + scratch[9].imag() * yb.imag());
scratch[6].imag(-scratch[10].real() * ya.imag() - scratch[9].real() * yb.imag());
Fout1[u] = scratch[5] - scratch[6]; //C_SUB(*Fout1,scratch[5],scratch[6]);
Fout4[u] = scratch[5] + scratch[6]; //C_ADD(*Fout4,scratch[5],scratch[6]);
scratch[11].real(scratch[0].real() + scratch[7].real()*yb.real() + scratch[8].real()*ya.real());
scratch[11].imag(scratch[0].imag() + scratch[7].imag()*yb.real() + scratch[8].imag()*ya.real());
scratch[12].real(- scratch[10].imag()*yb.imag() + scratch[9].imag()*ya.imag());
scratch[12].imag(scratch[10].real()*yb.imag() - scratch[9].real()*ya.imag());
Fout2[u] = scratch[11] + scratch[12];
Fout3[u] = scratch[11] - scratch[12];
}
}
/* perform the butterfly for one stage of a mixed radix FFT */
template<typename T>
inline void kf_bfly_generic(
std::complex<T> * Fout,
const size_t fstride,
const kiss_fft_state<T>& cfg,
const int m,
const int p
)
{
int u,k,q1,q;
const std::complex<T> * twiddles = &cfg.twiddles[0];
std::complex<T> t;
const int Norig = cfg.nfft;
std::vector<std::complex<T>> scratch(p);
for ( u=0; u<m; ++u ) {
k=u;
for ( q1=0 ; q1<p ; ++q1 ) {
scratch[q1] = Fout[ k ];
C_FIXDIV(scratch[q1],p);
k += m;
}
k=u;
for ( q1=0 ; q1<p ; ++q1 ) {
int twidx=0;
Fout[ k ] = scratch[0];
for (q=1;q<p;++q ) {
twidx += fstride * k;
if (twidx>=Norig) twidx-=Norig;
t = scratch[q] * twiddles[twidx];
Fout[ k ] += t;
}
k += m;
}
}
}
template<typename T>
inline void kf_work(
const kiss_fft_state<T>& cfg,
const int* factors,
std::complex<T>* Fout,
const std::complex<T>* f,
const size_t fstride,
const int in_stride
)
{
std::complex<T> * Fout_beg = Fout;
const int p=*factors++; /* the radix */
const int m=*factors++; /* stage's fft length/p */
const std::complex<T> * Fout_end = Fout + p*m;
if (m==1) {
do{
*Fout = *f;
f += fstride*in_stride;
}while(++Fout != Fout_end );
}else{
do{
// recursive call:
// DFT of size m*p performed by doing
// p instances of smaller DFTs of size m,
// each one takes a decimated version of the input
kf_work(cfg, factors, Fout , f, fstride*p, in_stride);
f += fstride*in_stride;
}while( (Fout += m) != Fout_end );
}
Fout=Fout_beg;
// recombine the p smaller DFTs
switch (p) {
case 2: kf_bfly2(Fout,fstride,cfg,m); break;
case 3: kf_bfly3(Fout,fstride,cfg,m); break;
case 4: kf_bfly4(Fout,fstride,cfg,m); break;
case 5: kf_bfly5(Fout,fstride,cfg,m); break;
default: kf_bfly_generic(Fout,fstride,cfg,m,p); break;
}
}
/* facbuf is populated by p1,m1,p2,m2, ...
where
p[i] * m[i] = m[i-1]
m0 = n */
inline void kf_factor(int n, std::vector<int>& facbuf)
{
int p=4;
const double floor_sqrt = std::floor( std::sqrt((double)n) );
/*factor out powers of 4, powers of 2, then any remaining primes */
do {
while (n % p) {
switch (p) {
case 4: p = 2; break;
case 2: p = 3; break;
default: p += 2; break;
}
if (p > floor_sqrt)
p = n; /* no more factors, skip to end */
}
n /= p;
facbuf.push_back(p);
facbuf.push_back(n);
} while (n > 1);
}
template<typename T>
inline kiss_fft_state<T>::kiss_fft_state(const plan_key& key)
{
constexpr double twopi = 6.283185307179586476925286766559005768394338798;
nfft = key.dims[0];
inverse = key.is_inverse;
twiddles.resize(nfft);
for (int i = 0 ; i < nfft ; ++i)
{
double phase = -twopi*i / nfft;
if (inverse)
phase *= -1;
twiddles[i] = std::polar(1.0, phase);
}
kf_factor(nfft,factors);
}
template<typename T>
void kiss_fft_stride(const kiss_fft_state<T>& cfg, const std::complex<T>* in, std::complex<T>* out,int fin_stride)
{
if (in == out)
{
DLIB_ASSERT(out != nullptr, "out buffer is NULL!");
std::vector<std::complex<T>> tmpbuf(cfg.nfft);
kiss_fft_stride(cfg, in, &tmpbuf[0], fin_stride);
std::copy(tmpbuf.begin(), tmpbuf.end(), out);
}
else
{
kf_work(cfg, &cfg.factors[0], out, in, 1, fin_stride);
}
}
template<typename T>
inline kiss_fftnd_state<T>::kiss_fftnd_state(const plan_key& key)
{
dims = key.dims;
for (size_t i = 0 ; i < dims.num_dims() ; i++)
plans.push_back(std::move(kiss_fft_state<T>(plan_key({dims[i]}, key.is_inverse))));
}
template<typename T>
void kiss_fftnd(const kiss_fftnd_state<T>& cfg, const std::complex<T>* in, std::complex<T>* out)
{
const std::complex<T>* bufin=in;
std::complex<T>* bufout;
std::vector<std::complex<T>> tmpbuf(cfg.dims.num_elements());
/*arrange it so the last bufout == out*/
if ( cfg.dims.num_dims() & 1 )
{
bufout = out;
if (in==out) {
std::copy(in, in + cfg.dims.num_elements(), tmpbuf.begin());
bufin = &tmpbuf[0];
}
}
else
bufout = &tmpbuf[0];
for (size_t k=0; k < cfg.dims.num_dims(); ++k)
{
int curdim = cfg.dims[k];
int stride = cfg.dims.num_elements() / curdim;
for (int i=0 ; i<stride ; ++i )
kiss_fft_stride(cfg.plans[k], bufin+i , bufout+i*curdim, stride );
/*toggle back and forth between the two buffers*/
if (bufout == &tmpbuf[0])
{
bufout = out;
bufin = &tmpbuf[0];
}
else
{
bufout = &tmpbuf[0];
bufin = out;
}
}
}
template<typename T>
inline kiss_fftr_state<T>::kiss_fftr_state(const plan_key& key)
{
DLIB_ASSERT((key.dims[0] & 1) == 0, "real FFT must have even dimension");
const int nfft = key.dims[0] / 2;
substate = kiss_fft_state<T>(plan_key({nfft}, key.is_inverse));
super_twiddles.resize(nfft/2);
for (size_t i = 0 ; i < super_twiddles.size() ; ++i)
{
double phase = -3.141592653589793238462643383279502884197169399 * ((double) (i+1) / nfft + .5);
if (key.is_inverse)
phase *= -1;
super_twiddles[i] = std::polar(1.0, phase);
}
}
template<typename T>
void kiss_fftr(const kiss_fftr_state<T>& plan, const T* timedata, std::complex<T>* freqdata)
{
DLIB_ASSERT(!plan.substate.inverse, "bad fftr plan : need a forward plan. This is an inverse plan");
const int nfft_h = plan.substate.nfft; //recall that the FFT size is actually half the original requested FFT size, i.e. the size of timedata
/*perform the parallel fft of two real signals packed in real,imag*/
std::vector<std::complex<T>> tmpbuf(nfft_h);
kiss_fft_stride(plan.substate, reinterpret_cast<const std::complex<T>*>(timedata), &tmpbuf[0], 1);
/* The real part of the DC element of the frequency spectrum in st->tmpbuf
* contains the sum of the even-numbered elements of the input time sequence
* The imag part is the sum of the odd-numbered elements
*
* The sum of tdc.r and tdc.i is the sum of the input time sequence.
* yielding DC of input time sequence
* The difference of tdc.r - tdc.i is the sum of the input (dot product) [1,-1,1,-1...
* yielding Nyquist bin of input time sequence
*/
freqdata[0] = std::complex<T>(tmpbuf[0].real() + tmpbuf[0].imag(), 0);
freqdata[nfft_h] = std::complex<T>(tmpbuf[0].real() - tmpbuf[0].imag(), 0);
constexpr T half = 0.5;
for (int k = 1 ; k <= nfft_h / 2 ; ++k)
{
const auto fpk = tmpbuf[k];
const auto fpnk = std::conj(tmpbuf[nfft_h-k]);
const auto f1k = fpk + fpnk;
const auto f2k = fpk - fpnk;
const auto tw = f2k * plan.super_twiddles[k-1];
freqdata[k] = half * (f1k + tw);
freqdata[nfft_h-k] = half * std::conj(f1k - tw);
}
}
template<typename T>
void kiss_ifftr(const kiss_fftr_state<T>& plan, const std::complex<T>* freqdata, T* timedata)
{
DLIB_ASSERT(plan.substate.inverse, "bad Ifftr plan : need an inverse plan. This is a forward plan")
const int nfft_h = plan.substate.nfft; //recall that the FFT size is actually half the original requested FFT size, i.e. the size of timedata
std::vector<std::complex<T>> tmpbuf(nfft_h);
tmpbuf[0] = std::complex<T>(freqdata[0].real() + freqdata[nfft_h].real(),
freqdata[0].real() - freqdata[nfft_h].real());
for (int k = 1; k <= nfft_h / 2; ++k)
{
std::complex<T> fk = freqdata[k];
std::complex<T> fnkc = std::conj(freqdata[nfft_h - k]);
auto fek = fk + fnkc;
auto tmp = fk - fnkc;
auto fok = tmp * plan.super_twiddles[k-1];
tmpbuf[k] = fek + fok;
tmpbuf[nfft_h - k] = std::conj(fek - fok);
}
kiss_fft_stride (plan.substate, &tmpbuf[0], (std::complex<T>*)timedata, 1);
}
template<typename T>
inline kiss_fftndr_state<T>::kiss_fftndr_state(const plan_key& key)
{
const long realdim = key.dims.back();
const fft_size otherdims = pop_back(key.dims);
cfg_r = kiss_fftr_state<T>(plan_key({realdim}, key.is_inverse));
cfg_nd = kiss_fftnd_state<T>(plan_key(otherdims, key.is_inverse));
}
template<typename T>
void kiss_fftndr(const kiss_fftndr_state<T>& plan, const T* timedata, std::complex<T>* freqdata)
{
const int dimReal = plan.cfg_r.substate.nfft*2; //recall the real fft size is half the length of the input
const int dimOther = plan.cfg_nd.dims.num_elements();
const int nrbins = dimReal/2+1;
std::vector<std::complex<T>> tmp1(std::max<int>(nrbins, dimOther));
std::vector<std::complex<T>> tmp2(plan.cfg_nd.dims.num_elements()*dimReal);
// take a real chunk of data, fft it and place the output at correct intervals
for (int k1 = 0; k1 < dimOther; ++k1)
{
kiss_fftr(plan.cfg_r, timedata + k1*dimReal , &tmp1[0]); // tmp1 now holds nrbins complex points
for (int k2 = 0; k2 < nrbins; ++k2)
tmp2[k2*dimOther+k1] = tmp1[k2];
}
for (int k2 = 0; k2 < nrbins; ++k2)
{
kiss_fftnd(plan.cfg_nd, &tmp2[k2*dimOther], &tmp1[0]); // tmp1 now holds dimOther complex points
for (int k1 = 0; k1 < dimOther; ++k1)
freqdata[ k1*(nrbins) + k2] = tmp1[k1];
}
}
template<typename T>
void kiss_ifftndr(const kiss_fftndr_state<T>& plan, const std::complex<T>* freqdata, T* timedata)
{
const int dimReal = plan.cfg_r.substate.nfft*2; //recall the real fft size is half the length of the input
const int dimOther = plan.cfg_nd.dims.num_elements();
const int nrbins = dimReal/2+1;
std::vector<std::complex<T>> tmp1(std::max<int>(nrbins, dimOther));
std::vector<std::complex<T>> tmp2(plan.cfg_nd.dims.num_elements()*dimReal);
for (int k2 = 0; k2 < nrbins; ++k2)
{
for (int k1 = 0; k1 < dimOther; ++k1)
tmp1[k1] = freqdata[ k1*(nrbins) + k2 ];
kiss_fftnd(plan.cfg_nd, &tmp1[0], &tmp2[k2*dimOther]);
}
for (int k1 = 0; k1 < dimOther; ++k1)
{
for (int k2 = 0; k2 < nrbins; ++k2)
tmp1[k2] = tmp2[ k2*dimOther+k1 ];
kiss_ifftr(plan.cfg_r, &tmp1[0], timedata + k1*dimReal);
}
}
struct hasher
{
size_t operator()(const plan_key& key) const {return key.hash();}
};
template<typename plan_type>
const plan_type& get_plan(const plan_key& key)
{
static std::mutex m;
static std::unordered_map<plan_key, plan_type, hasher> plans;
std::lock_guard<std::mutex> l(m);
auto it = plans.find(key);
if (it != plans.end())
{
return it->second;
}
else
{
plans[key] = plan_type(key);
return plans[key];
}
}
}
template<typename T>
void kiss_fft(const fft_size& dims, const std::complex<T>* in, std::complex<T>* out, bool is_inverse)
/*!
requires
- T must be either float or double
- dims represents the dimensions of both `in` and `out`
- dims.num_dims() > 0
ensures
- performs an FFT on `in` and stores the result in `out`.
- if `is_inverse` is true, a backward FFT is performed,
otherwise a forward FFT is performed.
!*/
{
using namespace kiss_details;
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floating point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
const fft_size squeezed_dims = squeeze_ones(dims);
if (squeezed_dims.num_elements() == 1)
{
if (in != out)
{
out[0] = in[0];
}
}
else if (squeezed_dims.num_dims() == 1)
{
const auto& plan = get_plan<kiss_fft_state<T>>({squeezed_dims, is_inverse});
kiss_fft_stride(plan, in, out, 1);
}
else
{
const auto& plan = get_plan<kiss_fftnd_state<T>>({squeezed_dims,is_inverse});
kiss_fftnd(plan, in, out);
}
}
/*
* in has dims[0] * dims[1] * ... * dims[-2] * dims[-1] points
* out has dims[0] * dims[1] * ... * dims[-2] * (dims[-1]/2+1) points
*/
template<typename T>
void kiss_fftr(const fft_size& dims, const T* in, std::complex<T>* out)
/*!
requires
- T must be either float or double
- dims represent the dimensions of `in`
- `out` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]/2+1}
- dims.num_dims() > 0
- dims.back() must be even
ensures
- performs a real FFT on `in` and stores the result in `out`.
!*/
{
using namespace kiss_details;
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floating point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
DLIB_ASSERT(dims.back() % 2 == 0, "last dimension needs to be even");
const fft_size squeezed_dims = squeeze_ones(dims);
if (squeezed_dims.num_dims() == 1)
{
const auto& plan = get_plan<kiss_fftr_state<T>>({squeezed_dims,false});
kiss_fftr(plan, in, out);
}
else
{
const auto& plan = get_plan<kiss_fftndr_state<T>>({squeezed_dims,false});
kiss_fftndr(plan, in, out);
}
}
/*
* in has dims[0] * dims[1] * ... * dims[-2] * (dims[-1]/2+1) points
* out has dims[0] * dims[1] * ... * dims[-2] * dims[-1] points
*/
template<typename T>
void kiss_ifftr(const fft_size& dims, const std::complex<T>* in, T* out)
/*!
requires
- T must be either float or double
- dims represent the dimensions of `out`
- `in` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]/2+1}
- dims.num_dims() > 0
- dims.back() must be even
ensures
- performs an inverse real FFT on `in` and stores the result in `out`.
!*/
{
using namespace kiss_details;
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floating point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
DLIB_ASSERT(dims.back() % 2 == 0, "last dimension needs to be even");
const fft_size squeezed_dims = squeeze_ones(dims);
if (squeezed_dims.num_dims() == 1)
{
const auto& plan = get_plan<kiss_fftr_state<T>>({squeezed_dims,true});
kiss_ifftr(plan, in, out);
}
else
{
const auto& plan = get_plan<kiss_fftndr_state<T>>({squeezed_dims,true});
kiss_ifftndr(plan, in, out);
}
}
inline int kiss_fft_next_fast_size(int n)
{
while(1) {
int m=n;
while ( (m%2) == 0 ) m/=2;
while ( (m%3) == 0 ) m/=3;
while ( (m%5) == 0 ) m/=5;
if (m<=1)
break; /* n is completely factorable by twos, threes, and fives */
n++;
}
return n;
}
inline int kiss_fftr_next_fast_size_real(int n)
{
return kiss_fft_next_fast_size((n+1)>>1) << 1;
}
}
#endif // DLIB_KISS_FFT_H
...@@ -9,836 +9,187 @@ ...@@ -9,836 +9,187 @@
#include "../algs.h" #include "../algs.h"
#ifdef DLIB_USE_MKL_FFT #ifdef DLIB_USE_MKL_FFT
#include <mkl_dfti.h> #include "mkl_fft.h"
#endif #else
#include "kiss_fft.h"
// No using FFTW until it becomes thread safe!
#if 0
#ifdef DLIB_USE_FFTW
#include <fftw3.h>
#endif // DLIB_USE_FFTW
#endif #endif
namespace dlib namespace dlib
{ {
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline bool is_power_of_two ( constexpr bool is_power_of_two (const unsigned long n)
const unsigned long& value
)
{ {
if (value == 0) return n == 0 ? true : (n & (n - 1)) == 0;
return true;
else
return count_bits(value) == 1;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
namespace impl constexpr long fftr_nc_size(long nc)
{ {
return nc == 0 ? 0 : nc/2+1;
// ------------------------------------------------------------------------------------
/*
The next few functions related to doing FFTs are derived from Stefan
Gustavson's (stegu@itn.liu.se) public domain 2D Fourier transformation code.
The code has a long history, originally a FORTRAN implementation published in:
Programming for Digital Signal Processing, IEEE Press 1979, Section 1, by G. D.
Bergland and M. T. Dolan. In 2003 it was cleaned up and turned into modern C
by Steven Gustavson. Davis King then rewrote it in modern C++ in 2014 and also
changed the transform so that the outputs are identical to those given from FFTW.
*/
// ------------------------------------------------------------------------------------
/* Get binary log of integer argument - exact if n is a power of 2 */
inline long fastlog2(long n)
{
long log = -1;
while(n) {
log++;
n >>= 1;
}
return log ;
}
// ------------------------------------------------------------------------------------
/* Radix-2 iteration subroutine */
template <typename T>
void R2TX(int nthpo, std::complex<T> *c0, std::complex<T> *c1)
{
for(int k=0; k<nthpo; k+=2)
{
std::complex<T> temp = c0[k] + c1[k];
c1[k] = c0[k] - c1[k];
c0[k] = temp;
}
}
// ------------------------------------------------------------------------------------
/* Radix-4 iteration subroutine */
template <typename T>
void R4TX(int nthpo, std::complex<T> *c0, std::complex<T> *c1,
std::complex<T> *c2, std::complex<T> *c3)
{
for(int k=0;k<nthpo;k+=4)
{
std::complex<T> t1, t2, t3, t4;
t1 = c0[k] + c2[k];
t2 = c0[k] - c2[k];
t3 = c1[k] + c3[k];
t4 = c1[k] - c3[k];
c0[k] = t1 + t3;
c1[k] = t1 - t3;
c2[k] = std::complex<T>(t2.real()-t4.imag(), t2.imag()+t4.real());
c3[k] = std::complex<T>(t2.real()+t4.imag(), t2.imag()-t4.real());
}
}
// ------------------------------------------------------------------------------------
template <typename T>
class twiddles
{
/*!
The point of this object is to cache the twiddle values so we don't
recompute them over and over inside R8TX().
!*/
public:
twiddles()
{
data.resize(64);
}
const std::complex<T>* get_twiddles (
int p
)
/*!
requires
- 0 <= p <= 64
ensures
- returns a pointer to the twiddle factors needed by R8TX if nxtlt == 2^p
!*/
{
// Compute the twiddle factors for this p value if we haven't done so
// already.
if (data[p].size() == 0)
{
const int nxtlt = 0x1 << p;
data[p].reserve(nxtlt*7);
const T twopi = 6.2831853071795865; /* 2.0 * pi */
const T scale = twopi/(nxtlt*8.0);
std::complex<T> cs[7];
for (int j = 0; j < nxtlt; ++j)
{
const T arg = j*scale;
cs[0] = std::complex<T>(std::cos(arg),std::sin(arg));
cs[1] = cs[0]*cs[0];
cs[2] = cs[1]*cs[0];
cs[3] = cs[1]*cs[1];
cs[4] = cs[2]*cs[1];
cs[5] = cs[2]*cs[2];
cs[6] = cs[3]*cs[2];
data[p].insert(data[p].end(), cs, cs+7);
}
}
return &data[p][0];
}
private:
std::vector<std::vector<std::complex<T> > > data;
};
// ----------------------------------------------------------------------------------------
/* Radix-8 iteration subroutine */
template <typename T>
void R8TX(int nxtlt, int nthpo, int length, const std::complex<T>* cs,
std::complex<T> *cc0, std::complex<T> *cc1, std::complex<T> *cc2, std::complex<T> *cc3,
std::complex<T> *cc4, std::complex<T> *cc5, std::complex<T> *cc6, std::complex<T> *cc7)
{
const T irt2 = 0.707106781186548; /* 1.0/sqrt(2.0) */
for(int j=0; j<nxtlt; j++)
{
for(int k=j;k<nthpo;k+=length)
{
std::complex<T> a0, a1, a2, a3, a4, a5, a6, a7;
std::complex<T> b0, b1, b2, b3, b4, b5, b6, b7;
a0 = cc0[k] + cc4[k];
a1 = cc1[k] + cc5[k];
a2 = cc2[k] + cc6[k];
a3 = cc3[k] + cc7[k];
a4 = cc0[k] - cc4[k];
a5 = cc1[k] - cc5[k];
a6 = cc2[k] - cc6[k];
a7 = cc3[k] - cc7[k];
b0 = a0 + a2;
b1 = a1 + a3;
b2 = a0 - a2;
b3 = a1 - a3;
b4 = std::complex<T>(a4.real()-a6.imag(), a4.imag()+a6.real());
b5 = std::complex<T>(a5.real()-a7.imag(), a5.imag()+a7.real());
b6 = std::complex<T>(a4.real()+a6.imag(), a4.imag()-a6.real());
b7 = std::complex<T>(a5.real()+a7.imag(), a5.imag()-a7.real());
const std::complex<T> tmp0(-b3.imag(), b3.real());
const std::complex<T> tmp1(irt2*(b5.real()-b5.imag()), irt2*(b5.real()+b5.imag()));
const std::complex<T> tmp2(-irt2*(b7.real()+b7.imag()), irt2*(b7.real()-b7.imag()));
cc0[k] = b0 + b1;
cc1[k] = b0 - b1;
cc2[k] = b2 + tmp0;
cc3[k] = b2 - tmp0;
cc4[k] = b4 + tmp1;
cc5[k] = b4 - tmp1;
cc6[k] = b6 + tmp2;
cc7[k] = b6 - tmp2;
if(j>0)
{
cc1[k] *= cs[3];
cc2[k] *= cs[1];
cc3[k] *= cs[5];
cc4[k] *= cs[0];
cc5[k] *= cs[4];
cc6[k] *= cs[2];
cc7[k] *= cs[6];
}
}
cs += 7;
}
}
// ------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename MM, typename layout>
void fft1d_inplace(matrix<std::complex<T>,NR,NC,MM,layout>& data, bool do_backward_fft, twiddles<T>& cs)
/*!
requires
- is_vector(data) == true
- is_power_of_two(data.size()) == true
ensures
- This routine replaces the input std::complex<double> vector by its finite
discrete complex fourier transform if do_backward_fft==true. It replaces
the input std::complex<double> vector by its finite discrete complex
inverse fourier transform if do_backward_fft==false.
The implementation is a radix-2 FFT, but with faster shortcuts for
radix-4 and radix-8. It performs as many radix-8 iterations as possible,
and then finishes with a radix-2 or -4 iteration if needed.
!*/
{
COMPILE_TIME_ASSERT((is_same_type<double,T>::value || is_same_type<float,T>::value || is_same_type<long double,T>::value ));
if (data.size() == 0)
return;
std::complex<T>* const b = &data(0);
int L[16],L1,L2,L3,L4,L5,L6,L7,L8,L9,L10,L11,L12,L13,L14,L15;
int j1,j2,j3,j4,j5,j6,j7,j8,j9,j10,j11,j12,j13,j14;
int j, ij, ji;
int n2pow, n8pow, nthpo, ipass, nxtlt, length;
n2pow = fastlog2(data.size());
nthpo = data.size();
n8pow = n2pow/3;
if(n8pow)
{
/* Radix 8 iterations */
for(ipass=1;ipass<=n8pow;ipass++)
{
const int p = n2pow - 3*ipass;
nxtlt = 0x1 << p;
length = 8*nxtlt;
R8TX(nxtlt, nthpo, length, cs.get_twiddles(p),
b, b+nxtlt, b+2*nxtlt, b+3*nxtlt,
b+4*nxtlt, b+5*nxtlt, b+6*nxtlt, b+7*nxtlt);
}
}
if(n2pow%3 == 1)
{
/* A final radix 2 iteration is needed */
R2TX(nthpo, b, b+1);
}
if(n2pow%3 == 2)
{
/* A final radix 4 iteration is needed */
R4TX(nthpo, b, b+1, b+2, b+3);
}
for(j=1;j<=15;j++)
{
L[j] = 1;
if(j-n2pow <= 0) L[j] = 0x1 << (n2pow + 1 - j);
}
L15=L[1];L14=L[2];L13=L[3];L12=L[4];L11=L[5];L10=L[6];L9=L[7];
L8=L[8];L7=L[9];L6=L[10];L5=L[11];L4=L[12];L3=L[13];L2=L[14];L1=L[15];
ij = 0;
for(j1=0;j1<L1;j1++)
for(j2=j1;j2<L2;j2+=L1)
for(j3=j2;j3<L3;j3+=L2)
for(j4=j3;j4<L4;j4+=L3)
for(j5=j4;j5<L5;j5+=L4)
for(j6=j5;j6<L6;j6+=L5)
for(j7=j6;j7<L7;j7+=L6)
for(j8=j7;j8<L8;j8+=L7)
for(j9=j8;j9<L9;j9+=L8)
for(j10=j9;j10<L10;j10+=L9)
for(j11=j10;j11<L11;j11+=L10)
for(j12=j11;j12<L12;j12+=L11)
for(j13=j12;j13<L13;j13+=L12)
for(j14=j13;j14<L14;j14+=L13)
for(ji=j14;ji<L15;ji+=L14)
{
if(ij<ji)
swap(b[ij], b[ji]);
ij++;
}
// unscramble outputs
if(!do_backward_fft)
{
for(long i=1, j=data.size()-1; i<data.size()/2; i++,j--)
{
swap(b[j], b[i]);
}
}
}
// ------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L >
void fft2d_inplace(
matrix<std::complex<T>,NR,NC,MM,L>& data,
bool do_backward_fft
)
{
if (data.size() == 0)
return;
matrix<std::complex<double> > buff;
twiddles<double> cs;
// Compute transform row by row
for(long r=0; r<data.nr(); ++r)
{
buff = matrix_cast<std::complex<double> >(rowm(data,r));
fft1d_inplace(buff, do_backward_fft, cs);
set_rowm(data,r) = matrix_cast<std::complex<T> >(buff);
}
// Compute transform column by column
for(long c=0; c<data.nc(); ++c)
{
buff = matrix_cast<std::complex<double> >(colm(data,c));
fft1d_inplace(buff, do_backward_fft, cs);
set_colm(data,c) = matrix_cast<std::complex<T> >(buff);
}
}
// ----------------------------------------------------------------------------------------
template <
typename EXP,
typename T
>
void fft2d(
const matrix_exp<EXP>& data,
matrix<std::complex<T> >& data_out,
bool do_backward_fft
)
{
// make sure requires clause is not broken
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
"\t matrix fft(data)"
<< "\n\t The number of rows and columns must be powers of two."
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
if (data.size() == 0)
return;
matrix<std::complex<double> > buff;
data_out.set_size(data.nr(), data.nc());
twiddles<double> cs;
// Compute transform row by row
for(long r=0; r<data.nr(); ++r)
{
buff = matrix_cast<std::complex<double> >(rowm(data,r));
fft1d_inplace(buff, do_backward_fft, cs);
set_rowm(data_out,r) = matrix_cast<std::complex<T> >(buff);
}
// Compute transform column by column
for(long c=0; c<data_out.nc(); ++c)
{
buff = matrix_cast<std::complex<double> >(colm(data_out,c));
fft1d_inplace(buff, do_backward_fft, cs);
set_colm(data_out,c) = matrix_cast<std::complex<T> >(buff);
}
}
// ------------------------------------------------------------------------------------
} // end namespace impl
// ----------------------------------------------------------------------------------------
template <typename EXP>
matrix<typename EXP::type> fft (const matrix_exp<EXP>& data)
{
// You have to give a complex matrix
COMPILE_TIME_ASSERT(is_complex<typename EXP::type>::value);
// make sure requires clause is not broken
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
"\t matrix fft(data)"
<< "\n\t The number of rows and columns must be powers of two."
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
if (data.nr() == 1 || data.nc() == 1)
{
matrix<typename EXP::type> temp(data);
impl::twiddles<typename EXP::type::value_type> cs;
impl::fft1d_inplace(temp, false, cs);
return temp;
}
else
{
matrix<typename EXP::type> temp;
impl::fft2d(data, temp, false);
return temp;
}
} }
template <typename EXP> // ----------------------------------------------------------------------------------------
matrix<typename EXP::type> ifft (const matrix_exp<EXP>& data)
constexpr long ifftr_nc_size(long nc)
{ {
// You have to give a complex matrix return nc == 0 ? 0 : 2*(nc-1);
COMPILE_TIME_ASSERT(is_complex<typename EXP::type>::value);
// make sure requires clause is not broken
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
"\t matrix ifft(data)"
<< "\n\t The number of rows and columns must be powers of two."
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
matrix<typename EXP::type> temp;
if (data.size() == 0)
return temp;
if (data.nr() == 1 || data.nc() == 1)
{
temp = data;
impl::twiddles<typename EXP::type::value_type> cs;
impl::fft1d_inplace(temp, true, cs);
}
else
{
impl::fft2d(data, temp, true);
}
temp /= data.size();
return temp;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L > template < typename T, long NR, long NC, typename MM, typename L >
typename enable_if_c<NR==1||NC==1>::type fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data) matrix<std::complex<T>,NR,NC,MM,L> fft (const matrix<std::complex<T>,NR,NC,MM,L>& in)
// Note that we don't divide the outputs by data.size() so this isn't quite the inverse.
{ {
// make sure requires clause is not broken //complex FFT
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), static_assert(std::is_floating_point<T>::value, "only support floating point types");
"\t void fft_inplace(data)" matrix<std::complex<T>,NR,NC,MM,L> out(in.nr(), in.nc());
<< "\n\t The number of rows and columns must be powers of two." if (in.size() != 0)
<< "\n\t data.nr(): "<< data.nr() {
<< "\n\t data.nc(): "<< data.nc() #ifdef DLIB_USE_MKL_FFT
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) mkl_fft({in.nr(),in.nc()}, &in(0,0), &out(0,0), false);
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) #else
); kiss_fft({in.nr(),in.nc()}, &in(0,0), &out(0,0), false);
#endif
impl::twiddles<T> cs; }
impl::fft1d_inplace(data, false, cs); return out;
} }
template < typename T, long NR, long NC, typename MM, typename L > // ----------------------------------------------------------------------------------------
typename disable_if_c<NR==1||NC==1>::type fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
// Note that we don't divide the outputs by data.size() so this isn't quite the inverse. template <typename EXP>
typename EXP::matrix_type fft (const matrix_exp<EXP>& data)
{ {
// make sure requires clause is not broken //complex FFT for expression template
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), static_assert(is_complex<typename EXP::type>::value, "input should be complex");
"\t void fft_inplace(data)" typename EXP::matrix_type in(data);
<< "\n\t The number of rows and columns must be powers of two." return fft(in);
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
impl::fft2d_inplace(data, false);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L > template < typename T, long NR, long NC, typename MM, typename L >
typename enable_if_c<NR==1||NC==1>::type ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data) matrix<std::complex<T>,NR,NC,MM,L> ifft (const matrix<std::complex<T>,NR,NC,MM,L>& in)
{ {
// make sure requires clause is not broken //inverse complex FFT
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), static_assert(std::is_floating_point<T>::value, "only support floating point types");
"\t void ifft_inplace(data)" matrix<std::complex<T>,NR,NC,MM,L> out(in.nr(), in.nc());
<< "\n\t The number of rows and columns must be powers of two." if (in.size() != 0)
<< "\n\t data.nr(): "<< data.nr() {
<< "\n\t data.nc(): "<< data.nc() #ifdef DLIB_USE_MKL_FFT
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) mkl_fft({in.nr(),in.nc()}, &in(0,0), &out(0,0), true);
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) #else
); kiss_fft({in.nr(),in.nc()}, &in(0,0), &out(0,0), true);
#endif
impl::twiddles<T> cs; out /= out.size();
impl::fft1d_inplace(data, true, cs); }
return out;
} }
template < typename T, long NR, long NC, typename MM, typename L > // ----------------------------------------------------------------------------------------
typename disable_if_c<NR==1||NC==1>::type ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
template <typename EXP>
typename EXP::matrix_type ifft (const matrix_exp<EXP>& data)
{ {
// make sure requires clause is not broken //inverse complex FFT for expression template
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), static_assert(is_complex<typename EXP::type>::value, "input should be complex");
"\t void ifft_inplace(data)" typename EXP::matrix_type in(data);
<< "\n\t The number of rows and columns must be powers of two." return ifft(in);
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
impl::fft2d_inplace(data, true);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
/* template<typename T, long NR, long NC, typename MM, typename L>
I'm disabling any use of the FFTW bindings because FFTW is, as of this writing, not matrix<std::complex<T>,NR,fftr_nc_size(NC),MM,L> fftr (const matrix<T,NR,NC,MM,L>& in)
threadsafe as a library. This means that if multiple threads were to make
concurrent calls to these fft routines then the program could crash. If at some
point FFTW is fixed I'll turn these bindings back on.
See https://github.com/FFTW/fftw3/issues/16
*/
#if 0
#ifdef DLIB_USE_FFTW
template <long NR, long NC, typename MM, typename L>
matrix<std::complex<double>,NR,NC,MM,L> call_fftw_fft(
const matrix<std::complex<double>,NR,NC,MM,L>& data
)
{ {
// make sure requires clause is not broken //real FFT
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), static_assert(std::is_floating_point<T>::value, "only support floating point types");
"\t matrix fft(data)" DLIB_ASSERT(in.nc() % 2 == 0, "last dimension " << in.nc() << " needs to be even otherwise ifftr(fftr(data)) won't have matching dimensions");
<< "\n\t The number of rows and columns must be powers of two." matrix<std::complex<T>,NR,fftr_nc_size(NC),MM,L> out(in.nr(), fftr_nc_size(in.nc()));
<< "\n\t data.nr(): "<< data.nr() if (in.size() != 0)
<< "\n\t data.nc(): "<< data.nc() {
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr()) #ifdef DLIB_USE_MKL_FFT
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) mkl_fftr({in.nr(),in.nc()}, &in(0,0), &out(0,0));
); #else
kiss_fftr({in.nr(),in.nc()}, &in(0,0), &out(0,0));
if (data.size() == 0) #endif
return data; }
return out;
matrix<std::complex<double>,NR,NC,MM,L> m2(data.nr(),data.nc());
fftw_complex *in, *out;
fftw_plan p;
in = (fftw_complex*)&data(0,0);
out = (fftw_complex*)&m2(0,0);
if (data.nr() == 1 || data.nc() == 1)
p = fftw_plan_dft_1d(data.size(), in, out, FFTW_FORWARD, FFTW_ESTIMATE);
else
p = fftw_plan_dft_2d(data.nr(), data.nc(), in, out, FFTW_FORWARD, FFTW_ESTIMATE);
fftw_execute(p);
fftw_destroy_plan(p);
return m2;
} }
template <long NR, long NC, typename MM, typename L> // ----------------------------------------------------------------------------------------
matrix<std::complex<double>,NR,NC,MM,L> call_fftw_ifft(
const matrix<std::complex<double>,NR,NC,MM,L>& data template <typename EXP>
) matrix<add_complex_t<typename EXP::type>> fftr (const matrix_exp<EXP>& data)
{ {
// make sure requires clause is not broken //real FFT for expression template
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), static_assert(std::is_floating_point<typename EXP::type>::value, "input should be real");
"\t matrix ifft(data)" matrix<typename EXP::type> in(data);
<< "\n\t The number of rows and columns must be powers of two." return fft(in);
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
if (data.size() == 0)
return data;
matrix<std::complex<double>,NR,NC,MM,L> m2(data.nr(),data.nc());
fftw_complex *in, *out;
fftw_plan p;
in = (fftw_complex*)&data(0,0);
out = (fftw_complex*)&m2(0,0);
if (data.nr() == 1 || data.nc() == 1)
p = fftw_plan_dft_1d(data.size(), in, out, FFTW_BACKWARD, FFTW_ESTIMATE);
else
p = fftw_plan_dft_2d(data.nr(), data.nc(), in, out, FFTW_BACKWARD, FFTW_ESTIMATE);
fftw_execute(p);
fftw_destroy_plan(p);
return m2;
} }
// ----------------------------------------------------------------------------------------
// call FFTW for these cases:
inline matrix<std::complex<double>,0,1> fft (const matrix<std::complex<double>,0,1>& data) {return call_fftw_fft(data);}
inline matrix<std::complex<double>,0,1> ifft(const matrix<std::complex<double>,0,1>& data) {return call_fftw_ifft(data)/data.size();}
inline matrix<std::complex<double>,1,0> fft (const matrix<std::complex<double>,1,0>& data) {return call_fftw_fft(data);}
inline matrix<std::complex<double>,1,0> ifft(const matrix<std::complex<double>,1,0>& data) {return call_fftw_ifft(data)/data.size();}
inline matrix<std::complex<double> > fft (const matrix<std::complex<double> >& data) {return call_fftw_fft(data);}
inline matrix<std::complex<double> > ifft(const matrix<std::complex<double> >& data) {return call_fftw_ifft(data)/data.size();}
inline void fft_inplace (matrix<std::complex<double>,0,1>& data) {data = call_fftw_fft(data);}
inline void ifft_inplace(matrix<std::complex<double>,0,1>& data) {data = call_fftw_ifft(data);}
inline void fft_inplace (matrix<std::complex<double>,1,0>& data) {data = call_fftw_fft(data);}
inline void ifft_inplace(matrix<std::complex<double>,1,0>& data) {data = call_fftw_ifft(data);}
inline void fft_inplace (matrix<std::complex<double> >& data) {data = call_fftw_fft(data);}
inline void ifft_inplace(matrix<std::complex<double> >& data) {data = call_fftw_ifft(data);}
#endif // DLIB_USE_FFTW
#endif // end of #if 0
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
#ifdef DLIB_USE_MKL_FFT template<typename T, long NR, long NC, typename MM, typename L>
matrix<T,NR,ifftr_nc_size(NC),MM,L> ifftr (const matrix<std::complex<T>,NR,NC,MM,L>& in)
#define DLIB_DFTI_CHECK_STATUS(s) \
if((s) != 0 && !DftiErrorClass((s), DFTI_NO_ERROR)) \
{ \
throw dlib::error(DftiErrorMessage((s))); \
}
template < long NR, long NC, typename MM, typename L >
matrix<std::complex<double>,NR,NC,MM,L> call_mkl_fft(
const matrix<std::complex<double>,NR,NC,MM,L>& data,
bool do_backward_fft)
{ {
// make sure requires clause is not broken //inverse real FFT
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), static_assert(std::is_floating_point<T>::value, "only support floating point types");
"\t matrix fft(data)" matrix<T,NR,ifftr_nc_size(NC),MM,L> out(in.nr(), ifftr_nc_size(in.nc()));
<< "\n\t The number of rows and columns must be powers of two." if (in.size() != 0)
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
if (data.size() == 0)
return data;
DFTI_DESCRIPTOR_HANDLE h;
MKL_LONG status;
if (data.nr() == 1 || data.nc() == 1)
{
status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 1, data.size());
DLIB_DFTI_CHECK_STATUS(status);
}
else
{ {
MKL_LONG size[2]; #ifdef DLIB_USE_MKL_FFT
size[0] = data.nr(); mkl_ifftr({out.nr(),out.nc()}, &in(0,0), &out(0,0));
size[1] = data.nc(); #else
kiss_ifftr({out.nr(),out.nc()}, &in(0,0), &out(0,0));
status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 2, size); #endif
DLIB_DFTI_CHECK_STATUS(status); out /= out.size();
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = size[1];
strides[2] = 1;
status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
} }
status = DftiSetValue(h, DFTI_PLACEMENT, DFTI_NOT_INPLACE);
DLIB_DFTI_CHECK_STATUS(status);
// Unless we use sequential mode, the fft results are not correct.
status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiCommitDescriptor(h);
DLIB_DFTI_CHECK_STATUS(status);
matrix<std::complex<double>,NR,NC,MM,L> out(data.nr(), data.nc());
if (do_backward_fft)
status = DftiComputeBackward(h, (void *)(&data(0, 0)), &out(0,0));
else
status = DftiComputeForward(h, (void *)(&data(0, 0)), &out(0,0));
DLIB_DFTI_CHECK_STATUS(status);
status = DftiFreeDescriptor(&h);
DLIB_DFTI_CHECK_STATUS(status);
return out; return out;
} }
template < long NR, long NC, typename MM, typename L > // ----------------------------------------------------------------------------------------
void call_mkl_fft_inplace(
matrix<std::complex<double>,NR,NC,MM,L>& data, template <typename EXP>
bool do_backward_fft matrix<remove_complex_t<typename EXP::type>> ifftr (const matrix_exp<EXP>& data)
)
{ {
// make sure requires clause is not broken //inverse real FFT for expression template
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), static_assert(is_complex<typename EXP::type>::value, "input should be complex");
"\t void ifft_inplace(data)" matrix<typename EXP::type> in(data);
<< "\n\t The number of rows and columns must be powers of two." return ifftr(in);
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
if (data.size() == 0)
return;
DFTI_DESCRIPTOR_HANDLE h;
MKL_LONG status;
if (data.nr() == 1 || data.nc() == 1)
{
status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 1, data.size());
DLIB_DFTI_CHECK_STATUS(status);
}
else
{
MKL_LONG size[2];
size[0] = data.nr();
size[1] = data.nc();
status = DftiCreateDescriptor(&h, DFTI_DOUBLE, DFTI_COMPLEX, 2, size);
DLIB_DFTI_CHECK_STATUS(status);
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = size[1];
strides[2] = 1;
status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
// Unless we use sequential mode, the fft results are not correct.
status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiCommitDescriptor(h);
DLIB_DFTI_CHECK_STATUS(status);
if (do_backward_fft)
status = DftiComputeBackward(h, &data(0, 0));
else
status = DftiComputeForward(h, &data(0, 0));
DLIB_DFTI_CHECK_STATUS(status);
status = DftiFreeDescriptor(&h);
DLIB_DFTI_CHECK_STATUS(status);
return;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// Call the MKL DFTI implementation in these cases template < typename T, long NR, long NC, typename MM, typename L >
void fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
inline matrix<std::complex<double>,0,1> fft (const matrix<std::complex<double>,0,1>& data)
{
return call_mkl_fft(data, false);
}
inline matrix<std::complex<double>,0,1> ifft(const matrix<std::complex<double>,0,1>& data)
{
return call_mkl_fft(data, true) / data.size();
}
inline matrix<std::complex<double>,1,0> fft (const matrix<std::complex<double>,1,0>& data)
{
return call_mkl_fft(data, false);
}
inline matrix<std::complex<double>,1,0> ifft(const matrix<std::complex<double>,1,0>& data)
{
return call_mkl_fft(data, true) / data.size();
}
inline matrix<std::complex<double> > fft (const matrix<std::complex<double> >& data)
{
return call_mkl_fft(data, false);
}
inline matrix<std::complex<double> > ifft(const matrix<std::complex<double> >& data)
{ {
return call_mkl_fft(data, true) / data.size(); if (data.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({data.nr(),data.nc()}, &data(0,0), &data(0,0), false);
#else
kiss_fft({data.nr(),data.nc()}, &data(0,0), &data(0,0), false);
#endif
}
} }
inline void fft_inplace (matrix<std::complex<double>,0,1>& data) // ----------------------------------------------------------------------------------------
{
call_mkl_fft_inplace(data, false);
}
inline void ifft_inplace(matrix<std::complex<double>,0,1>& data)
{
call_mkl_fft_inplace(data, true);
}
inline void fft_inplace (matrix<std::complex<double>,1,0>& data)
{
call_mkl_fft_inplace(data, false);
}
inline void ifft_inplace(matrix<std::complex<double>,1,0>& data)
{
call_mkl_fft_inplace(data, true);
}
inline void fft_inplace (matrix<std::complex<double> >& data) template < typename T, long NR, long NC, typename MM, typename L >
{ void ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
call_mkl_fft_inplace(data, false);
}
inline void ifft_inplace(matrix<std::complex<double> >& data)
{ {
call_mkl_fft_inplace(data, true); if (data.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({data.nr(),data.nc()}, &data(0,0), &data(0,0), true);
#else
kiss_fft({data.nr(),data.nc()}, &data(0,0), &data(0,0), true);
#endif
}
} }
#endif // DLIB_USE_MKL_FFT
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -11,8 +11,8 @@ namespace dlib ...@@ -11,8 +11,8 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
bool is_power_of_two ( constexpr bool is_power_of_two (
const unsigned long& value const unsigned long value
); );
/*! /*!
ensures ensures
...@@ -21,7 +21,27 @@ namespace dlib ...@@ -21,7 +21,27 @@ namespace dlib
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
constexpr long fftr_nc_size(
long nc
);
/*!
ensures
- returns the output dimension of a 1D real FFT
!*/
// ----------------------------------------------------------------------------------------
constexpr long ifftr_nc_size(
long nc
);
/*!
ensures
- returns the output dimension of an inverse 1D real FFT
!*/
// ----------------------------------------------------------------------------------------
template <typename EXP> template <typename EXP>
typename EXP::matrix_type fft ( typename EXP::matrix_type fft (
const matrix_exp<EXP>& data const matrix_exp<EXP>& data
...@@ -29,8 +49,6 @@ namespace dlib ...@@ -29,8 +49,6 @@ namespace dlib
/*! /*!
requires requires
- data contains elements of type std::complex<> that itself contains double, float, or long double. - data contains elements of type std::complex<> that itself contains double, float, or long double.
- is_power_of_two(data.nr()) == true
- is_power_of_two(data.nc()) == true
ensures ensures
- Computes the 1 or 2 dimensional discrete Fourier transform of the given data - Computes the 1 or 2 dimensional discrete Fourier transform of the given data
matrix and returns it. In particular, we return a matrix D such that: matrix and returns it. In particular, we return a matrix D such that:
...@@ -39,9 +57,9 @@ namespace dlib ...@@ -39,9 +57,9 @@ namespace dlib
- D(0,0) == the DC term of the Fourier transform. - D(0,0) == the DC term of the Fourier transform.
- starting with D(0,0), D contains progressively higher frequency components - starting with D(0,0), D contains progressively higher frequency components
of the input data. of the input data.
- ifft(D) == D - ifft(D) == data
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename EXP> template <typename EXP>
...@@ -51,8 +69,6 @@ namespace dlib ...@@ -51,8 +69,6 @@ namespace dlib
/*! /*!
requires requires
- data contains elements of type std::complex<> that itself contains double, float, or long double. - data contains elements of type std::complex<> that itself contains double, float, or long double.
- is_power_of_two(data.nr()) == true
- is_power_of_two(data.nc()) == true
ensures ensures
- Computes the 1 or 2 dimensional inverse discrete Fourier transform of the - Computes the 1 or 2 dimensional inverse discrete Fourier transform of the
given data vector and returns it. In particular, we return a matrix D such given data vector and returns it. In particular, we return a matrix D such
...@@ -62,8 +78,47 @@ namespace dlib ...@@ -62,8 +78,47 @@ namespace dlib
- fft(D) == data - fft(D) == data
!*/ !*/
// ----------------------------------------------------------------------------------------
template <typename EXP>
matrix<add_complex_t<typename EXP::type>> fftr (
const matrix_exp<EXP>& data
);
/*!
requires
- data contains elements of type double, float, or long double.
- data.nc() is even
ensures
- Computes the 1 or 2 dimensional real discrete Fourier transform of the given data
matrix and returns it. In particular, we return a matrix D such that:
- D.nr() == data.nr()
- D.nc() == fftr_nc_size(data.nc())
- D(0,0) == the DC term of the Fourier transform.
- starting with D(0,0), D contains progressively higher frequency components
of the input data.
- ifftr(D) == data
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename EXP>
matrix<remove_complex_t<typename EXP::type>> ifftr (
const matrix_exp<EXP>& data
);
/*!
requires
- data contains elements of type std::complex<> that itself contains double, float, or long double.
ensures
- Computes the 1 or 2 dimensional inverse real discrete Fourier transform of the
given data vector and returns it. In particular, we return a matrix D such
that:
- D.nr() == data.nr()
- D.nc() == ifftr_nc_size(data.nc())
- fftr(D) == data
!*/
// ----------------------------------------------------------------------------------------
template < template <
typename T, typename T,
long NR, long NR,
...@@ -77,8 +132,6 @@ namespace dlib ...@@ -77,8 +132,6 @@ namespace dlib
/*! /*!
requires requires
- data contains elements of type std::complex<> that itself contains double, float, or long double. - data contains elements of type std::complex<> that itself contains double, float, or long double.
- is_power_of_two(data.nr()) == true
- is_power_of_two(data.nc()) == true
ensures ensures
- This function is identical to fft() except that it does the FFT in-place. - This function is identical to fft() except that it does the FFT in-place.
That is, after this function executes we will have: That is, after this function executes we will have:
...@@ -100,8 +153,6 @@ namespace dlib ...@@ -100,8 +153,6 @@ namespace dlib
/*! /*!
requires requires
- data contains elements of type std::complex<> that itself contains double, float, or long double. - data contains elements of type std::complex<> that itself contains double, float, or long double.
- is_power_of_two(data.nr()) == true
- is_power_of_two(data.nc()) == true
ensures ensures
- This function is identical to ifft() except that it does the inverse FFT - This function is identical to ifft() except that it does the inverse FFT
in-place. That is, after this function executes we will have: in-place. That is, after this function executes we will have:
......
...@@ -48,6 +48,40 @@ namespace dlib ...@@ -48,6 +48,40 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
/*!A remove_complex
This is a template that can be used to remove std::complex from the underlying type.
For example:
remove_complex<float>::type == float
remove_complex<std::complex<float> >::type == float
!*/
template <typename T>
struct remove_complex {typedef T type;};
template <typename T>
struct remove_complex<std::complex<T> > {typedef T type;};
template<typename T>
using remove_complex_t = typename remove_complex<T>::type;
// ----------------------------------------------------------------------------------------
/*!A add_complex
This is a template that can be used to add std::complex to the underlying type if it isn't already complex.
For example:
add_complex<float>::type == std::complex<float>
add_complex<std::complex<float> >::type == std::complex<float>
!*/
template <typename T>
struct add_complex {typedef std::complex<T> type;};
template <typename T>
struct add_complex<std::complex<T> > {typedef std::complex<T> type;};
template<typename T>
using add_complex_t = typename add_complex<T>::type;
// ----------------------------------------------------------------------------------------
template <typename EXP> template <typename EXP>
inline bool is_row_vector ( inline bool is_row_vector (
const matrix_exp<EXP>& m const matrix_exp<EXP>& m
......
#ifndef DLIB_MKL_FFT_H
#define DLIB_MKL_FFT_H
#include <type_traits>
#include <mkl_dfti.h>
#include "fft_size.h"
#define DLIB_DFTI_CHECK_STATUS(s) \
if((s) != 0 && !DftiErrorClass((s), DFTI_NO_ERROR)) \
{ \
throw dlib::error(DftiErrorMessage((s))); \
}
namespace dlib
{
template<typename T>
void mkl_fft(const fft_size& dims, const std::complex<T>* in, std::complex<T>* out, bool is_inverse)
/*!
requires
- T must be either float or double
- dims represents the dimensions of both `in` and `out`
- dims.num_dims() > 0
- dims.num_dims() < 3
ensures
- performs an FFT on `in` and stores the result in `out`.
- if `is_inverse` is true, a backward FFT is performed,
otherwise a forward FFT is performed.
!*/
{
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floatint point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
DLIB_ASSERT(dims.num_dims() < 3, "we currently only support up to 2D FFT. Please submit an issue on github if 3D or above is required.");
constexpr DFTI_CONFIG_VALUE dfti_type = std::is_same<T,float>::value ? DFTI_SINGLE : DFTI_DOUBLE;
DFTI_DESCRIPTOR_HANDLE h;
MKL_LONG status;
if (dims.num_dims() == 1)
{
status = DftiCreateDescriptor(&h, dfti_type, DFTI_COMPLEX, 1, dims[0]);
DLIB_DFTI_CHECK_STATUS(status);
}
else
{
MKL_LONG size[] = {dims[0], dims[1]};
status = DftiCreateDescriptor(&h, dfti_type, DFTI_COMPLEX, 2, size);
DLIB_DFTI_CHECK_STATUS(status);
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = size[1];
strides[2] = 1;
status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
const DFTI_CONFIG_VALUE inplacefft = in == out ? DFTI_INPLACE : DFTI_NOT_INPLACE;
status = DftiSetValue(h, DFTI_PLACEMENT, inplacefft);
DLIB_DFTI_CHECK_STATUS(status);
// Unless we use sequential mode, the fft results are not correct.
status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiCommitDescriptor(h);
DLIB_DFTI_CHECK_STATUS(status);
if (is_inverse)
status = DftiComputeBackward(h, (void*)in, (void*)out);
else
status = DftiComputeForward(h, (void*)in, (void*)out);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiFreeDescriptor(&h);
DLIB_DFTI_CHECK_STATUS(status);
}
/*
* in has dims[0] * dims[1] * ... * dims[-2] * dims[-1] points
* out has dims[0] * dims[1] * ... * dims[-2] * (dims[-1]/2+1) points
*/
template<typename T>
void mkl_fftr(const fft_size& dims, const T* in, std::complex<T>* out)
/*!
requires
- T must be either float or double
- dims represent the dimensions of `in`
- `out` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]/2+1}
- dims.num_dims() > 0
- dims.num_dims() <= 3
- dims.back() must be even
ensures
- performs a real FFT on `in` and stores the result in `out`.
!*/
{
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floatint point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
DLIB_ASSERT(dims.num_dims() < 3, "we currently only support up to 2D FFT. Please submit an issue on github if 3D or above is required.");
DLIB_ASSERT(dims.back() % 2 == 0, "last dimension needs to be even");
constexpr DFTI_CONFIG_VALUE dfti_type = std::is_same<T,float>::value ? DFTI_SINGLE : DFTI_DOUBLE;
DFTI_DESCRIPTOR_HANDLE h;
MKL_LONG status;
if (dims.num_dims() == 1)
{
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 1, dims[0]);
DLIB_DFTI_CHECK_STATUS(status);
}
else
{
const long lastdim = dims[1]/2+1;
MKL_LONG size[] = {dims[0], dims[1]};
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 2, size);
DLIB_DFTI_CHECK_STATUS(status);
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = size[1];
strides[2] = 1;
status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = lastdim;
strides[2] = 1;
status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
}
const DFTI_CONFIG_VALUE inplacefft = (void*)in == (void*)out ? DFTI_INPLACE : DFTI_NOT_INPLACE;
status = DftiSetValue(h, DFTI_PLACEMENT, inplacefft);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiSetValue(h, DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
DLIB_DFTI_CHECK_STATUS(status);
// Unless we use sequential mode, the fft results are not correct.
status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiCommitDescriptor(h);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiComputeForward(h, (void*)in, (void*)out);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiFreeDescriptor(&h);
DLIB_DFTI_CHECK_STATUS(status);
}
/*
* in has dims[0] * dims[1] * ... * dims[-2] * (dims[-1]/2+1) points
* out has dims[0] * dims[1] * ... * dims[-2] * dims[-1] points
*/
template<typename T>
void mkl_ifftr(const fft_size& dims, const std::complex<T>* in, T* out)
/*!
requires
- T must be either float or double
- dims represent the dimensions of `out`
- `in` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]/2+1}
- dims.num_dims() > 0
- dims.num_dims() <= 3
- dims.back() must be even
ensures
- performs an inverse real FFT on `in` and stores the result in `out`.
!*/
{
static_assert(std::is_floating_point<T>::value, "template parameter needs to be a floatint point type");
DLIB_ASSERT(dims.num_dims() > 0, "dims can't be empty");
DLIB_ASSERT(dims.num_dims() < 3, "we currently only support up to 2D FFT. Please submit an issue on github if 3D or above is required.");
DLIB_ASSERT(dims.back() % 2 == 0, "last dimension needs to be even");
constexpr DFTI_CONFIG_VALUE dfti_type = std::is_same<T,float>::value ? DFTI_SINGLE : DFTI_DOUBLE;
DFTI_DESCRIPTOR_HANDLE h;
MKL_LONG status;
if (dims.num_dims() == 1)
{
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 1, dims[0]);
DLIB_DFTI_CHECK_STATUS(status);
}
else
{
const long lastdim = dims[1]/2+1;
MKL_LONG size[] = {dims[0], dims[1]};
status = DftiCreateDescriptor(&h, dfti_type, DFTI_REAL, 2, size);
DLIB_DFTI_CHECK_STATUS(status);
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = lastdim;
strides[2] = 1;
status = DftiSetValue(h, DFTI_INPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
{
MKL_LONG strides[3];
strides[0] = 0;
strides[1] = dims[1];
strides[2] = 1;
status = DftiSetValue(h, DFTI_OUTPUT_STRIDES, strides);
DLIB_DFTI_CHECK_STATUS(status);
}
}
const DFTI_CONFIG_VALUE inplacefft = (void*)in == (void*)out ? DFTI_INPLACE : DFTI_NOT_INPLACE;
status = DftiSetValue(h, DFTI_PLACEMENT, inplacefft);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiSetValue(h, DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
DLIB_DFTI_CHECK_STATUS(status);
// Unless we use sequential mode, the fft results are not correct.
status = DftiSetValue(h, DFTI_THREAD_LIMIT, 1);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiCommitDescriptor(h);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiComputeBackward(h, (void*)in, (void*)out);
DLIB_DFTI_CHECK_STATUS(status);
status = DftiFreeDescriptor(&h);
DLIB_DFTI_CHECK_STATUS(status);
}
}
#endif // DLIB_MKL_FFT_H
...@@ -11,7 +11,12 @@ ...@@ -11,7 +11,12 @@
#include <dlib/compress_stream.h> #include <dlib/compress_stream.h>
#include <dlib/base64.h> #include <dlib/base64.h>
#ifdef DLIB_USE_MKL_FFT
#include <dlib/matrix/kiss_fft.h>
#include <dlib/matrix/mkl_fft.h>
#endif
#include "tester.h" #include "tester.h"
#include "fftr_good_data.h"
namespace namespace
{ {
...@@ -21,19 +26,35 @@ namespace ...@@ -21,19 +26,35 @@ namespace
using namespace std; using namespace std;
logger dlog("test.fft"); logger dlog("test.fft");
static dlib::rand rnd(10000);
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
matrix<complex<double> > rand_complex(long nr, long nc) template<typename R>
matrix<complex<R> > rand_complex(long nr, long nc, R scale = 10.0)
{
matrix<complex<R> > m(nr,nc);
for (long r = 0; r < m.nr(); ++r)
{
for (long c = 0; c < m.nc(); ++c)
{
m(r,c) = std::complex<R>(rnd.get_random_gaussian() * scale, rnd.get_random_gaussian() * scale);
}
}
return m;
}
template<typename R>
matrix<R> rand_real(long nr, long nc)
{ {
static dlib::rand rnd; matrix<R> m(nr,nc);
matrix<complex<double> > m(nr,nc);
for (long r = 0; r < m.nr(); ++r) for (long r = 0; r < m.nr(); ++r)
{ {
for (long c = 0; c < m.nc(); ++c) for (long c = 0; c < m.nc(); ++c)
{ {
m(r,c) = complex<double>(rnd.get_random_gaussian()*10, rnd.get_random_gaussian()*10); m(r,c) = rnd.get_random_gaussian() * 10.0;
} }
} }
return m; return m;
...@@ -65,35 +86,78 @@ namespace ...@@ -65,35 +86,78 @@ namespace
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void test_against_saved_good_fftrs()
{
std::stringstream base64_in, decompressed_in, decompressed_out;
dlib::base64 base64_coder;
dlib::compress_stream::kernel_1ea compressor;
base64_in = get_fftr_stringstream();
base64_coder.decode(base64_in, decompressed_in);
compressor.decompress(decompressed_in, decompressed_out);
matrix<double> m1;
matrix<complex<double>> m2;
while (decompressed_out.peek() != EOF)
{
print_spinner();
deserialize(m1,decompressed_out);
deserialize(m2,decompressed_out);
DLIB_TEST(max(norm(fftr(m1)-m2)) < 1e-16);
DLIB_TEST(max(squared(m1-ifftr(m2))) < 1e-16);
}
}
// ----------------------------------------------------------------------------------------
void test_random_ffts() void test_random_ffts()
{ {
for (int iter = 0; iter < 10; ++iter) int test = 0;
for (int nr = 1; nr <= 64; nr++)
{ {
print_spinner(); for (int nc = 1; nc <= 64; nc++)
for (int nr = 1; nr <= 128; nr*=2)
{ {
for (int nc = 1; nc <= 128; nc *= 2) if (++test % 100 == 0)
{ print_spinner();
const matrix<complex<double> > m1 = rand_complex(nr,nc);
const matrix<complex<float> > fm1 = matrix_cast<complex<float> >(rand_complex(nr,nc)); const matrix<complex<double> > m1 = rand_complex<double>(nr,nc);
const matrix<complex<float> > fm1 = rand_complex<float>(nr,nc);
DLIB_TEST(max(norm(ifft(fft(m1))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(fm1))-fm1)) < 1e-7); DLIB_TEST(max(norm(ifft(fft(m1))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(fm1))-fm1)) < 1e-7);
matrix<complex<double> > temp = m1;
matrix<complex<float> > ftemp = fm1; matrix<complex<double> > temp = m1;
fft_inplace(temp); matrix<complex<float> > ftemp = fm1;
fft_inplace(ftemp); fft_inplace(temp);
DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16); fft_inplace(ftemp);
DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7); DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16);
ifft_inplace(temp); DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7);
ifft_inplace(ftemp); ifft_inplace(temp);
DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16); ifft_inplace(ftemp);
DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7); DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16);
} DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7);
} }
} }
{
// test size 0 matrices.
matrix<complex<double>> temp;
matrix<complex<float>> ftemp;
fft_inplace(temp);
fft_inplace(ftemp);
DLIB_TEST(temp.size() == 0);
DLIB_TEST(ftemp.size() == 0);
DLIB_TEST(fft(temp).size() == 0);
DLIB_TEST(ifft(temp).size() == 0);
matrix<double> rtemp;
DLIB_TEST(fftr(rtemp).size() == 0);
DLIB_TEST(ifftr(temp).size() == 0);
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -102,8 +166,8 @@ namespace ...@@ -102,8 +166,8 @@ namespace
void test_real_compile_time_sized_ffts() void test_real_compile_time_sized_ffts()
{ {
print_spinner(); print_spinner();
const matrix<complex<double>,nr,nc> m1 = complex_matrix(real(rand_complex(nr,nc))); const matrix<complex<double>,nr,nc> m1 = complex_matrix(rand_real<double>(nr,nc));
const matrix<complex<float>,nr,nc> fm1 = matrix_cast<complex<float> >(complex_matrix(real(rand_complex(nr,nc)))); const matrix<complex<float>,nr,nc> fm1 = complex_matrix(rand_real<float>(nr,nc));
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16); DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7); DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7);
...@@ -122,40 +186,302 @@ namespace ...@@ -122,40 +186,302 @@ namespace
void test_random_real_ffts() void test_random_real_ffts()
{ {
for (int iter = 0; iter < 10; ++iter) int test = 0;
for (int nr = 1; nr <= 64; nr++)
{ {
print_spinner(); for (int nc = 1; nc <= 64; nc++)
for (int nr = 1; nr <= 128; nr*=2)
{ {
for (int nc = 1; nc <= 128; nc *= 2) if (++test % 100 == 0)
{ print_spinner();
const matrix<complex<double> > m1 = complex_matrix(real(rand_complex(nr,nc)));
const matrix<complex<float> > fm1 = matrix_cast<complex<float> >(complex_matrix(real(rand_complex(nr,nc)))); const matrix<complex<double> > m1 = complex_matrix(rand_real<double>(nr,nc));
const matrix<complex<float> > fm1 = complex_matrix(rand_real<float>(nr,nc));
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7); DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7);
matrix<complex<double> > temp = m1;
matrix<complex<float> > ftemp = fm1; matrix<complex<double> > temp = m1;
fft_inplace(temp); matrix<complex<float> > ftemp = fm1;
fft_inplace(ftemp); fft_inplace(temp);
DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16); fft_inplace(ftemp);
DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7); DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16);
ifft_inplace(temp); DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7);
ifft_inplace(ftemp); ifft_inplace(temp);
DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16); ifft_inplace(ftemp);
DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7); DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16);
} DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7);
} }
} }
test_real_compile_time_sized_ffts<16,16>(); test_real_compile_time_sized_ffts<16,16>();
test_real_compile_time_sized_ffts<16,1>(); test_real_compile_time_sized_ffts<16,1>();
test_real_compile_time_sized_ffts<1,16>(); test_real_compile_time_sized_ffts<1,16>();
test_real_compile_time_sized_ffts<480,480>(); //2^5 * 3 * 5
test_real_compile_time_sized_ffts<480,1>(); //2^5 * 3 * 5
test_real_compile_time_sized_ffts<1,480>(); //2^5 * 3 * 5
test_real_compile_time_sized_ffts<131,131>(); //some large prime
test_real_compile_time_sized_ffts<131,1>(); //some large prime
test_real_compile_time_sized_ffts<1,131>(); //some large prime
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template<typename R>
void test_linearity_complex()
{
static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 5e-2;
static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
int test = 0;
auto func = [&](long nr, long nc)
{
if (++test % 100 == 0)
print_spinner();
const matrix<complex<R>> m1 = rand_complex<R>(nr,nc);
const matrix<complex<R>> m2 = rand_complex<R>(nr,nc);
const R a1 = rnd.get_double_in_range(-10.0, 10.0);
const R a2 = rnd.get_double_in_range(-10.0, 10.0);
const matrix<complex<R>> m3 = a1*m1 + a2*m2;
const matrix<complex<R>> f1 = fft(m1);
const matrix<complex<R>> f2 = fft(m2);
const matrix<complex<R>> f3 = fft(m3);
R diff = max(norm(f3 - a1*f1 - a2*f2));
DLIB_TEST_MSG(diff < tol, "diff " << diff << " not within tol " << tol << " where (nr,nc) = (" << nr << "," << nc << ")" << " type " << typelabel);
const matrix<complex<R>> m4 = ifft(f3);
diff = max(norm(m4 - m3));
DLIB_TEST_MSG(diff < tol, "diff " << diff << " not within tol " << tol << " where (nr,nc) = (" << nr << "," << nc << ")" << " type " << typelabel);
};
for (int nr = 1; nr <= 64; nr++)
{
for (int nc = 1; nc <= 64; nc++)
{
func(nr,nc);
}
}
//some odd balls...
func(3, 131); print_spinner();
func(123, 103); print_spinner();
}
// ----------------------------------------------------------------------------------------
template<typename R>
void test_linearity_real()
{
static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 1e-3;
static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
int test = 0;
auto func = [&](long nr, long nc)
{
if (++test % 100 == 0)
print_spinner();
const matrix<R> m1 = rand_real<R>(nr,nc);
const matrix<R> m2 = rand_real<R>(nr,nc);
const R a1 = rnd.get_double_in_range(-10.0, 10.0);
const R a2 = rnd.get_double_in_range(-10.0, 10.0);
const matrix<R> m3 = a1*m1 + a2*m2;
const matrix<complex<R>> f1 = fftr(m1);
const matrix<complex<R>> f2 = fftr(m2);
const matrix<complex<R>> f3 = fftr(m3);
DLIB_TEST(f1.nr() == m1.nr());
DLIB_TEST(f1.nc() == fftr_nc_size(m1.nc()));
R diff = max(norm(f3 - a1*f1 - a2*f2));
DLIB_TEST_MSG(diff < tol, "diff " << diff << " not within tol " << tol << " where (nr,nc) = (" << nr << "," << nc << ")" << " type " << typelabel);
const matrix<R> m4 = ifftr(f3);
DLIB_TEST(m4.nr() == f3.nr());
DLIB_TEST(m4.nc() == ifftr_nc_size(f3.nc()));
diff = max(squared(m4 - m3));
DLIB_TEST_MSG(diff < tol, "diff " << diff << " not within tol " << tol << " where (nr,nc) = (" << nr << "," << nc << ")" << " type " << typelabel);
};
for (int nr = 2; nr <= 64; nr += 2)
{
for (int nc = 2; nc <= 64; nc += 2)
{
func(nr,nc);
}
}
//some odd balls...
func(89, 102); print_spinner();
func(123, 48); print_spinner();
}
// ----------------------------------------------------------------------------------------
template<typename R>
void test_kronecker_delta_impulse_response()
{
static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 1e-3;
int test = 0;
auto func = [&](long nr, long nc)
{
if (++test % 100 == 0)
print_spinner();
matrix<R> ones = dlib::ones_matrix<R>(nr,nc);
matrix<complex<R>> x = dlib::zeros_matrix<complex<R>>(nr,nc);
x(0,0) = 1.0f;
matrix<complex<R>> f = fft(x);
R diff_real = max(squared(real(f) - ones));
R diff_imag = max(squared(imag(f)));
DLIB_TEST(diff_real < tol);
DLIB_TEST(diff_imag < tol);
};
for (int nr = 1; nr <= 64; nr++)
{
for (int nc = 1; nc <= 64; nc++)
{
func(nr,nc);
}
}
//some odd balls...
func(3, 131); print_spinner();
func(123, 103); print_spinner();
}
// ----------------------------------------------------------------------------------------
template<typename R>
void test_time_shift()
{
static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 1e-1;
static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
int test = 0;
auto func = [&](size_t size, size_t time_shift)
{
if (++test % 100 == 0)
print_spinner();
matrix<complex<R>> x1 = rand_complex<R>(1,size, 1.0);
matrix<complex<R>> x2 = x1;
std::rotate(x2.begin(), x2.begin() + time_shift, x2.end());
matrix<complex<R>> f1 = fft(x1);
matrix<complex<R>> f2 = fft(x2);
matrix<complex<R>> f2_expected = f1;
for (long i = 0 ; i < f1.size() ; i++)
f2_expected(i) = f1(i)*std::polar<R>(R(1), R(6.283185307179586476925286766559005768394338798*time_shift*i / size));
const auto diff_real = max(squared(real(f2) - real(f2_expected)));
const auto diff_imag = max(squared(imag(f2) - imag(f2_expected)));
DLIB_TEST_MSG(diff_real < tol, typelabel << " diff_real " << diff_real << " size " << size << " shift " << time_shift);
DLIB_TEST_MSG(diff_imag < tol, typelabel << " diff_real " << diff_imag << " size " << size << " shift " << time_shift);
};
for (size_t size = 16 ; size < 64 ; size++)
{
for (size_t time_shift = 10 ; time_shift < size/2 + 1 ; time_shift += 10)
{
func(size, time_shift);
}
}
//some odd balls...
func(3,1);
func(123,16);
func(123,122);
}
// ----------------------------------------------------------------------------------------
template<typename R>
void test_fftr_conjugacy_1D()
{
static constexpr double tol = std::is_same<R,double>::value ? 1e-15 : 1e-6;
static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
auto func = [&](long nc)
{
print_spinner();
matrix<R> m1 = rand_real<R>(1, nc);
matrix<complex<R>> f1 = fftr(m1);
matrix<complex<R>> f2 = fft(complex_matrix(m1));
matrix<complex<R>> f3 = join_rows(f1, conj(fliplr(colm(f1,range(1,f1.nc()-2)))));
const R diff = max(norm(f2-f3));
DLIB_TEST_MSG(diff < tol, typelabel << " diff " << diff << " nr " << m1.nr() << " nc " << m1.nc() << " tol " << tol);
};
//don't start from 2, as that is a special case where fft and fftr
//give the same number of columns.
//Therefore, fiplr(colm(f1,range(1,f1.nc()-2))) wouldn't work
for (long nc = 4 ; nc <= 128 ; nc+=2)
{
func(nc);
}
//some odd balls...
func(480);
func(130);
}
// ----------------------------------------------------------------------------------------
#ifdef DLIB_USE_MKL_FFT
template<typename R>
void test_kiss_vs_mkl()
{
static constexpr double tol = std::is_same<R,double>::value ? 1e-2 : 1e-2;
static constexpr const char* typelabel = std::is_same<R,double>::value ? "double" : "float";
int test = 0;
for (int nr = 2; nr <= 64; nr += 2)
{
for (int nc = 2; nc <= 64; nc += 2)
{
if (++test % 100 == 0)
print_spinner();
std::vector<float> x1(nr*nc), y1(nr*nc), y2(nr*nc);
std::vector<std::complex<float>> f1(nr*(nc/2+1)), f2(nr*(nc/2+1));
for (int i = 0 ; i < (nr*nc) ; i++)
x1[i] = rnd.get_random_gaussian();
kiss_fftr({nr,nc}, &x1[0], &f1[0]);
mkl_fftr({nr,nc}, &x1[0], &f2[0]);
const R diff1 = max(norm(mat(f1) - mat(f2)));
DLIB_TEST_MSG(diff1 < tol, typelabel << " diff1 " << diff1 << " nr " << nr << " nc " << nc);
kiss_ifftr({nr,nc}, &f1[0], &y1[0]);
mkl_ifftr({nr,nc}, &f2[0], &y2[0]);
const R diff2 = max(squared(mat(y1) - mat(y2)));
DLIB_TEST_MSG(diff2 < tol, typelabel << " diff2 " << diff2 << " nr " << nr << " nc " << nc);
}
}
}
#endif
class test_fft : public tester class test_fft : public tester
{ {
public: public:
...@@ -169,8 +495,23 @@ namespace ...@@ -169,8 +495,23 @@ namespace
) )
{ {
test_against_saved_good_ffts(); test_against_saved_good_ffts();
test_against_saved_good_fftrs();
test_random_ffts(); test_random_ffts();
test_random_real_ffts(); test_random_real_ffts();
test_linearity_real<float>();
test_linearity_real<double>();
test_linearity_complex<float>();
test_linearity_complex<double>();
test_kronecker_delta_impulse_response<float>();
test_kronecker_delta_impulse_response<double>();
test_time_shift<float>();
test_time_shift<double>();
test_fftr_conjugacy_1D<float>();
test_fftr_conjugacy_1D<double>();
#ifdef DLIB_USE_MKL_FFT
test_kiss_vs_mkl<float>();
test_kiss_vs_mkl<double>();
#endif
} }
} a; } a;
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment