Commit 0886042c authored by lishen's avatar lishen
Browse files

dlib from github, version=19.24

parent 5b127120
Pipeline #262 failed with stages
in 0 seconds
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ARRAy_
#define DLIB_ARRAy_
#include "array/array_kernel.h"
#include "array/array_tools.h"
#endif // DLIB_ARRAy_
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ARRAY_KERNEl_2_
#define DLIB_ARRAY_KERNEl_2_
#include "array_kernel_abstract.h"
#include "../interfaces/enumerable.h"
#include "../algs.h"
#include "../serialize.h"
#include "../sort.h"
#include "../is_kind.h"
namespace dlib
{
template <
typename T,
typename mem_manager = default_memory_manager
>
class array : public enumerable<T>
{
/*!
INITIAL VALUE
- array_size == 0
- max_array_size == 0
- array_elements == 0
- pos == 0
- last_pos == 0
- _at_start == true
CONVENTION
- array_size == size()
- max_array_size == max_size()
- if (max_array_size > 0)
- array_elements == pointer to max_array_size elements of type T
- else
- array_elements == 0
- if (array_size > 0)
- last_pos == array_elements + array_size - 1
- else
- last_pos == 0
- at_start() == _at_start
- current_element_valid() == pos != 0
- if (current_element_valid()) then
- *pos == element()
!*/
public:
// These typedefs are here for backwards compatibility with old versions of dlib.
typedef array kernel_1a;
typedef array kernel_1a_c;
typedef array kernel_2a;
typedef array kernel_2a_c;
typedef array sort_1a;
typedef array sort_1a_c;
typedef array sort_1b;
typedef array sort_1b_c;
typedef array sort_2a;
typedef array sort_2a_c;
typedef array sort_2b;
typedef array sort_2b_c;
typedef array expand_1a;
typedef array expand_1a_c;
typedef array expand_1b;
typedef array expand_1b_c;
typedef array expand_1c;
typedef array expand_1c_c;
typedef array expand_1d;
typedef array expand_1d_c;
typedef T type;
typedef T value_type;
typedef mem_manager mem_manager_type;
array (
) :
array_size(0),
max_array_size(0),
array_elements(0),
pos(0),
last_pos(0),
_at_start(true)
{}
array(const array&) = delete;
array& operator=(array&) = delete;
array(
array&& item
) : array()
{
swap(item);
}
array& operator=(
array&& item
)
{
swap(item);
return *this;
}
explicit array (
size_t new_size
) :
array_size(0),
max_array_size(0),
array_elements(0),
pos(0),
last_pos(0),
_at_start(true)
{
resize(new_size);
}
~array (
);
void clear (
);
inline const T& operator[] (
size_t pos
) const;
inline T& operator[] (
size_t pos
);
void set_size (
size_t size
);
inline size_t max_size(
) const;
void set_max_size(
size_t max
);
void swap (
array& item
);
// functions from the enumerable interface
inline size_t size (
) const;
inline bool at_start (
) const;
inline void reset (
) const;
bool current_element_valid (
) const;
inline const T& element (
) const;
inline T& element (
);
bool move_next (
) const;
void sort (
);
void resize (
size_t new_size
);
const T& back (
) const;
T& back (
);
void pop_back (
);
void pop_back (
T& item
);
void push_back (
T& item
);
void push_back (
T&& item
);
typedef T* iterator;
typedef const T* const_iterator;
iterator begin() { return array_elements; }
const_iterator begin() const { return array_elements; }
iterator end() { return array_elements+array_size; }
const_iterator end() const { return array_elements+array_size; }
private:
typename mem_manager::template rebind<T>::other pool;
// data members
size_t array_size;
size_t max_array_size;
T* array_elements;
mutable T* pos;
T* last_pos;
mutable bool _at_start;
};
template <
typename T,
typename mem_manager
>
inline void swap (
array<T,mem_manager>& a,
array<T,mem_manager>& b
) { a.swap(b); }
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void serialize (
const array<T,mem_manager>& item,
std::ostream& out
)
{
try
{
serialize(item.max_size(),out);
serialize(item.size(),out);
for (size_t i = 0; i < item.size(); ++i)
serialize(item[i],out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type array");
}
}
template <
typename T,
typename mem_manager
>
void deserialize (
array<T,mem_manager>& item,
std::istream& in
)
{
try
{
size_t max_size, size;
deserialize(max_size,in);
deserialize(size,in);
item.set_max_size(max_size);
item.set_size(size);
for (size_t i = 0; i < size; ++i)
deserialize(item[i],in);
}
catch (serialization_error& e)
{
item.clear();
throw serialization_error(e.info + "\n while deserializing object of type array");
}
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
array<T,mem_manager>::
~array (
)
{
if (array_elements)
{
pool.deallocate_array(array_elements);
}
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
clear (
)
{
reset();
last_pos = 0;
array_size = 0;
if (array_elements)
{
pool.deallocate_array(array_elements);
}
array_elements = 0;
max_array_size = 0;
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
const T& array<T,mem_manager>::
operator[] (
size_t pos
) const
{
// make sure requires clause is not broken
DLIB_ASSERT( pos < this->size() ,
"\tconst T& array::operator[]"
<< "\n\tpos must < size()"
<< "\n\tpos: " << pos
<< "\n\tsize(): " << this->size()
<< "\n\tthis: " << this
);
return array_elements[pos];
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
T& array<T,mem_manager>::
operator[] (
size_t pos
)
{
// make sure requires clause is not broken
DLIB_ASSERT( pos < this->size() ,
"\tT& array::operator[]"
<< "\n\tpos must be < size()"
<< "\n\tpos: " << pos
<< "\n\tsize(): " << this->size()
<< "\n\tthis: " << this
);
return array_elements[pos];
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
set_size (
size_t size
)
{
// make sure requires clause is not broken
DLIB_CASSERT(( size <= this->max_size() ),
"\tvoid array::set_size"
<< "\n\tsize must be <= max_size()"
<< "\n\tsize: " << size
<< "\n\tmax size: " << this->max_size()
<< "\n\tthis: " << this
);
reset();
array_size = size;
if (size > 0)
last_pos = array_elements + size - 1;
else
last_pos = 0;
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
size_t array<T,mem_manager>::
size (
) const
{
return array_size;
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
set_max_size(
size_t max
)
{
reset();
array_size = 0;
last_pos = 0;
if (max != 0)
{
// if new max size is different
if (max != max_array_size)
{
if (array_elements)
{
pool.deallocate_array(array_elements);
}
// try to get more memroy
try { array_elements = pool.allocate_array(max); }
catch (...) { array_elements = 0; max_array_size = 0; throw; }
max_array_size = max;
}
}
// if the array is being made to be zero
else
{
if (array_elements)
pool.deallocate_array(array_elements);
max_array_size = 0;
array_elements = 0;
}
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
size_t array<T,mem_manager>::
max_size (
) const
{
return max_array_size;
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
swap (
array<T,mem_manager>& item
)
{
auto array_size_temp = item.array_size;
auto max_array_size_temp = item.max_array_size;
T* array_elements_temp = item.array_elements;
item.array_size = array_size;
item.max_array_size = max_array_size;
item.array_elements = array_elements;
array_size = array_size_temp;
max_array_size = max_array_size_temp;
array_elements = array_elements_temp;
exchange(_at_start,item._at_start);
exchange(pos,item.pos);
exchange(last_pos,item.last_pos);
pool.swap(item.pool);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// enumerable function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
bool array<T,mem_manager>::
at_start (
) const
{
return _at_start;
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
reset (
) const
{
_at_start = true;
pos = 0;
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
bool array<T,mem_manager>::
current_element_valid (
) const
{
return pos != 0;
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
const T& array<T,mem_manager>::
element (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(this->current_element_valid(),
"\tconst T& array::element()"
<< "\n\tThe current element must be valid if you are to access it."
<< "\n\tthis: " << this
);
return *pos;
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
T& array<T,mem_manager>::
element (
)
{
// make sure requires clause is not broken
DLIB_ASSERT(this->current_element_valid(),
"\tT& array::element()"
<< "\n\tThe current element must be valid if you are to access it."
<< "\n\tthis: " << this
);
return *pos;
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
bool array<T,mem_manager>::
move_next (
) const
{
if (!_at_start)
{
if (pos < last_pos)
{
++pos;
return true;
}
else
{
pos = 0;
return false;
}
}
else
{
_at_start = false;
if (array_size > 0)
{
pos = array_elements;
return true;
}
else
{
return false;
}
}
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Yet more functions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
sort (
)
{
if (this->size() > 1)
{
// call the quick sort function for arrays that is in algs.h
dlib::qsort_array(*this,0,this->size()-1);
}
this->reset();
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
resize (
size_t new_size
)
{
if (this->max_size() < new_size)
{
array temp;
temp.set_max_size(new_size);
temp.set_size(new_size);
for (size_t i = 0; i < this->size(); ++i)
{
exchange((*this)[i],temp[i]);
}
temp.swap(*this);
}
else
{
this->set_size(new_size);
}
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
T& array<T,mem_manager>::
back (
)
{
// make sure requires clause is not broken
DLIB_ASSERT( this->size() > 0 ,
"\tT& array::back()"
<< "\n\tsize() must be bigger than 0"
<< "\n\tsize(): " << this->size()
<< "\n\tthis: " << this
);
return (*this)[this->size()-1];
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
const T& array<T,mem_manager>::
back (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT( this->size() > 0 ,
"\tconst T& array::back()"
<< "\n\tsize() must be bigger than 0"
<< "\n\tsize(): " << this->size()
<< "\n\tthis: " << this
);
return (*this)[this->size()-1];
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
pop_back (
T& item
)
{
// make sure requires clause is not broken
DLIB_ASSERT( this->size() > 0 ,
"\tvoid array::pop_back()"
<< "\n\tsize() must be bigger than 0"
<< "\n\tsize(): " << this->size()
<< "\n\tthis: " << this
);
exchange(item,(*this)[this->size()-1]);
this->set_size(this->size()-1);
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
pop_back (
)
{
// make sure requires clause is not broken
DLIB_ASSERT( this->size() > 0 ,
"\tvoid array::pop_back()"
<< "\n\tsize() must be bigger than 0"
<< "\n\tsize(): " << this->size()
<< "\n\tthis: " << this
);
this->set_size(this->size()-1);
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
push_back (
T& item
)
{
if (this->max_size() == this->size())
{
// double the size of the array
array temp;
temp.set_max_size(this->size()*2 + 1);
temp.set_size(this->size()+1);
for (size_t i = 0; i < this->size(); ++i)
{
exchange((*this)[i],temp[i]);
}
exchange(item,temp[temp.size()-1]);
temp.swap(*this);
}
else
{
this->set_size(this->size()+1);
exchange(item,(*this)[this->size()-1]);
}
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array<T,mem_manager>::
push_back (
T&& item
) { push_back(item); }
// ----------------------------------------------------------------------------------------
template <typename T, typename MM>
struct is_array <array<T,MM> >
{
const static bool value = true;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ARRAY_KERNEl_2_
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_ARRAY_KERNEl_ABSTRACT_
#ifdef DLIB_ARRAY_KERNEl_ABSTRACT_
#include "../interfaces/enumerable.h"
#include "../serialize.h"
#include "../algs.h"
namespace dlib
{
template <
typename T,
typename mem_manager = default_memory_manager
>
class array : public enumerable<T>
{
/*!
REQUIREMENTS ON T
T must have a default constructor.
REQUIREMENTS ON mem_manager
must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
mem_manager::type can be set to anything.
POINTERS AND REFERENCES TO INTERNAL DATA
front(), back(), swap(), max_size(), set_size(), and operator[]
functions do not invalidate pointers or references to internal data.
All other functions have no such guarantee.
INITIAL VALUE
size() == 0
max_size() == 0
ENUMERATION ORDER
The enumerator will iterate over the elements of the array in the
order (*this)[0], (*this)[1], (*this)[2], ...
WHAT THIS OBJECT REPRESENTS
This object represents an ordered 1-dimensional array of items,
each item is associated with an integer value. The items are
numbered from 0 though size() - 1 and the operator[] functions
run in constant time.
Also note that unless specified otherwise, no member functions
of this object throw exceptions.
!*/
public:
typedef T type;
typedef T value_type;
typedef mem_manager mem_manager_type;
array (
);
/*!
ensures
- #*this is properly initialized
throws
- std::bad_alloc or any exception thrown by T's constructor
!*/
explicit array (
size_t new_size
);
/*!
ensures
- #*this is properly initialized
- #size() == new_size
- #max_size() == new_size
- All elements of the array will have initial values for their type.
throws
- std::bad_alloc or any exception thrown by T's constructor
!*/
~array (
);
/*!
ensures
- all memory associated with *this has been released
!*/
array(
array&& item
);
/*!
ensures
- move constructs *this from item. Therefore, the state of item is
moved into *this and #item has a valid but unspecified state.
!*/
array& operator=(
array&& item
);
/*!
ensures
- move assigns *this from item. Therefore, the state of item is
moved into *this and #item has a valid but unspecified state.
- returns a reference to #*this
!*/
void clear (
);
/*!
ensures
- #*this has its initial value
throws
- std::bad_alloc or any exception thrown by T's constructor
if this exception is thrown then the array object is unusable
until clear() is called and succeeds
!*/
const T& operator[] (
size_t pos
) const;
/*!
requires
- pos < size()
ensures
- returns a const reference to the element at position pos
!*/
T& operator[] (
size_t pos
);
/*!
requires
- pos < size()
ensures
- returns a non-const reference to the element at position pos
!*/
void set_size (
size_t size
);
/*!
requires
- size <= max_size()
ensures
- #size() == size
- any element with index between 0 and size - 1 which was in the
array before the call to set_size() retains its value and index.
All other elements have undetermined (but valid for their type)
values. (e.g. this object might buffer old T objects and reuse
them without reinitializing them between calls to set_size())
- #at_start() == true
throws
- std::bad_alloc or any exception thrown by T's constructor
may throw this exception if there is not enough memory and
if it does throw then the call to set_size() has no effect
!*/
size_t max_size(
) const;
/*!
ensures
- returns the maximum size of *this
!*/
void set_max_size(
size_t max
);
/*!
ensures
- #max_size() == max
- #size() == 0
- #at_start() == true
throws
- std::bad_alloc or any exception thrown by T's constructor
may throw this exception if there is not enough
memory and if it does throw then max_size() == 0
!*/
void swap (
array<T>& item
);
/*!
ensures
- swaps *this and item
!*/
void sort (
);
/*!
requires
- T must be a type with that is comparable via operator<
ensures
- for all elements in #*this the ith element is <= the i+1 element
- #at_start() == true
throws
- std::bad_alloc or any exception thrown by T's constructor
data may be lost if sort() throws
!*/
void resize (
size_t new_size
);
/*!
ensures
- #size() == new_size
- #max_size() == max(new_size,max_size())
- for all i < size() && i < new_size:
- #(*this)[i] == (*this)[i]
(i.e. All the original elements of *this which were at index
values less than new_size are unmodified.)
- for all valid i >= size():
- #(*this)[i] has an undefined value
(i.e. any new elements of the array have an undefined value)
throws
- std::bad_alloc or any exception thrown by T's constructor.
If an exception is thrown then it has no effect on *this.
!*/
const T& back (
) const;
/*!
requires
- size() != 0
ensures
- returns a const reference to (*this)[size()-1]
!*/
T& back (
);
/*!
requires
- size() != 0
ensures
- returns a non-const reference to (*this)[size()-1]
!*/
void pop_back (
T& item
);
/*!
requires
- size() != 0
ensures
- #size() == size() - 1
- swaps (*this)[size()-1] into item
- All elements with an index less than size()-1 are
unmodified by this operation.
!*/
void pop_back (
);
/*!
requires
- size() != 0
ensures
- #size() == size() - 1
- All elements with an index less than size()-1 are
unmodified by this operation.
!*/
void push_back (
T& item
);
/*!
ensures
- #size() == size()+1
- swaps item into (*this)[#size()-1]
- #back() == item
- #item has some undefined value (whatever happens to
get swapped out of the array)
throws
- std::bad_alloc or any exception thrown by T's constructor.
If an exception is thrown then it has no effect on *this.
!*/
void push_back (T&& item) { push_back(item); }
/*!
enable push_back from rvalues
!*/
typedef T* iterator;
typedef const T* const_iterator;
iterator begin(
);
/*!
ensures
- returns an iterator that points to the first element in this array or
end() if the array is empty.
!*/
const_iterator begin(
) const;
/*!
ensures
- returns a const iterator that points to the first element in this
array or end() if the array is empty.
!*/
iterator end(
);
/*!
ensures
- returns an iterator that points to one past the end of the array.
!*/
const_iterator end(
) const;
/*!
ensures
- returns a const iterator that points to one past the end of the
array.
!*/
private:
// restricted functions
array(array<T>&); // copy constructor
array<T>& operator=(array<T>&); // assignment operator
};
template <
typename T
>
inline void swap (
array<T>& a,
array<T>& b
) { a.swap(b); }
/*!
provides a global swap function
!*/
template <
typename T
>
void serialize (
const array<T>& item,
std::ostream& out
);
/*!
provides serialization support
!*/
template <
typename T
>
void deserialize (
array<T>& item,
std::istream& in
);
/*!
provides deserialization support
!*/
}
#endif // DLIB_ARRAY_KERNEl_ABSTRACT_
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ARRAY_tOOLS_H_
#define DLIB_ARRAY_tOOLS_H_
#include "../assert.h"
#include "array_tools_abstract.h"
namespace dlib
{
template <typename T>
void split_array (
T& a,
T& b,
double frac
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 <= frac && frac <= 1,
"\t void split_array()"
<< "\n\t frac must be between 0 and 1."
<< "\n\t frac: " << frac
);
const unsigned long asize = static_cast<unsigned long>(a.size()*frac);
const unsigned long bsize = a.size()-asize;
b.resize(bsize);
for (unsigned long i = 0; i < b.size(); ++i)
{
swap(b[i], a[i+asize]);
}
a.resize(asize);
}
}
#endif // DLIB_ARRAY_tOOLS_H_
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_ARRAY_tOOLS_ABSTRACT_H_
#ifdef DLIB_ARRAY_tOOLS_ABSTRACT_H_
#include "array_kernel_abstract.h"
namespace dlib
{
template <typename T>
void split_array (
T& a,
T& b,
double frac
);
/*!
requires
- 0 <= frac <= 1
- T must be an array type such as dlib::array or std::vector
ensures
- This function takes the elements of a and splits them into two groups. The
first group remains in a and the second group is put into b. The ordering of
elements in a is preserved. In particular, concatenating #a with #b will
reproduce the original contents of a.
- The elements in a are moved around using global swap(). So they must be
swappable, but do not need to be copyable.
- #a.size() == floor(a.size()*frac)
- #b.size() == a.size()-#a.size()
!*/
}
#endif // DLIB_ARRAY_tOOLS_ABSTRACT_H_
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ARRAY2d_
#define DLIB_ARRAY2d_
#include "array2d/array2d_kernel.h"
#include "array2d/serialize_pixel_overloads.h"
#include "array2d/array2d_generic_image.h"
#endif // DLIB_ARRAY2d_
// Copyright (C) 2014 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ARRAY2D_GENERIC_iMAGE_Hh_
#define DLIB_ARRAY2D_GENERIC_iMAGE_Hh_
#include "array2d_kernel.h"
#include "../image_processing/generic_image.h"
namespace dlib
{
template <typename T, typename mm>
struct image_traits<array2d<T,mm> >
{
typedef T pixel_type;
};
template <typename T, typename mm>
struct image_traits<const array2d<T,mm> >
{
typedef T pixel_type;
};
template <typename T, typename mm>
inline long num_rows( const array2d<T,mm>& img) { return img.nr(); }
template <typename T, typename mm>
inline long num_columns( const array2d<T,mm>& img) { return img.nc(); }
template <typename T, typename mm>
inline void set_image_size(
array2d<T,mm>& img,
long rows,
long cols
) { img.set_size(rows,cols); }
template <typename T, typename mm>
inline void* image_data(
array2d<T,mm>& img
)
{
if (img.size() != 0)
return &img[0][0];
else
return 0;
}
template <typename T, typename mm>
inline const void* image_data(
const array2d<T,mm>& img
)
{
if (img.size() != 0)
return &img[0][0];
else
return 0;
}
template <typename T, typename mm>
inline long width_step(
const array2d<T,mm>& img
)
{
return img.width_step();
}
}
#endif // DLIB_ARRAY2D_GENERIC_iMAGE_Hh_
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ARRAY2D_KERNEl_1_
#define DLIB_ARRAY2D_KERNEl_1_
#include "array2d_kernel_abstract.h"
#include "../algs.h"
#include "../interfaces/enumerable.h"
#include "../serialize.h"
#include "../geometry/rectangle.h"
namespace dlib
{
template <
typename T,
typename mem_manager = default_memory_manager
>
class array2d : public enumerable<T>
{
/*!
INITIAL VALUE
- nc_ == 0
- nr_ == 0
- data == 0
- at_start_ == true
- cur == 0
- last == 0
CONVENTION
- nc_ == nc()
- nr_ == nc()
- if (data != 0) then
- last == a pointer to the last element in the data array
- data == pointer to an array of nc_*nr_ T objects
- else
- nc_ == 0
- nr_ == 0
- data == 0
- last == 0
- nr_ * nc_ == size()
- if (cur == 0) then
- current_element_valid() == false
- else
- current_element_valid() == true
- *cur == element()
- at_start_ == at_start()
!*/
class row_helper;
public:
// These typedefs are here for backwards compatibility with older versions of dlib.
typedef array2d kernel_1a;
typedef array2d kernel_1a_c;
typedef T type;
typedef mem_manager mem_manager_type;
typedef T* iterator;
typedef const T* const_iterator;
// -----------------------------------
class row
{
/*!
CONVENTION
- nc_ == nc()
- for all x < nc_:
- (*this)[x] == data[x]
!*/
friend class array2d<T,mem_manager>;
friend class row_helper;
public:
long nc (
) const { return nc_; }
const T& operator[] (
long column
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(column < nc() && column >= 0,
"\tconst T& array2d::operator[](long column) const"
<< "\n\tThe column index given must be less than the number of columns."
<< "\n\tthis: " << this
<< "\n\tcolumn: " << column
<< "\n\tnc(): " << nc()
);
return data[column];
}
T& operator[] (
long column
)
{
// make sure requires clause is not broken
DLIB_ASSERT(column < nc() && column >= 0,
"\tT& array2d::operator[](long column)"
<< "\n\tThe column index given must be less than the number of columns."
<< "\n\tthis: " << this
<< "\n\tcolumn: " << column
<< "\n\tnc(): " << nc()
);
return data[column];
}
private:
row(T* data_, long cols) : data(data_), nc_(cols) {}
row(row&& r) = default;
row& operator=(row&& r) = default;
T* data = nullptr;
long nc_ = 0;
// restricted functions
row(const row&) = delete;
row& operator=(const row&) = delete;
};
// -----------------------------------
array2d (
) :
data(0),
nc_(0),
nr_(0),
cur(0),
last(0),
at_start_(true)
{
}
array2d(
long rows,
long cols
) :
data(0),
nc_(0),
nr_(0),
cur(0),
last(0),
at_start_(true)
{
// make sure requires clause is not broken
DLIB_ASSERT((cols >= 0 && rows >= 0),
"\t array2d::array2d(long rows, long cols)"
<< "\n\t The array2d can't have negative rows or columns."
<< "\n\t this: " << this
<< "\n\t cols: " << cols
<< "\n\t rows: " << rows
);
set_size(rows,cols);
}
array2d(const array2d&) = delete; // copy constructor
array2d& operator=(const array2d&) = delete; // assignment operator
#ifdef DLIB_HAS_RVALUE_REFERENCES
array2d(array2d&& item) : array2d()
{
swap(item);
}
array2d& operator= (
array2d&& rhs
)
{
swap(rhs);
return *this;
}
#endif
virtual ~array2d (
) { clear(); }
long nc (
) const { return nc_; }
long nr (
) const { return nr_; }
row operator[] (
long row_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(row_ < nr() && row_ >= 0,
"\trow array2d::operator[](long row_)"
<< "\n\tThe row index given must be less than the number of rows."
<< "\n\tthis: " << this
<< "\n\trow_: " << row_
<< "\n\tnr(): " << nr()
);
return row(data+row_*nc_, nc_);
}
const row operator[] (
long row_
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(row_ < nr() && row_ >= 0,
"\tconst row array2d::operator[](long row_) const"
<< "\n\tThe row index given must be less than the number of rows."
<< "\n\tthis: " << this
<< "\n\trow_: " << row_
<< "\n\tnr(): " << nr()
);
return row(data+row_*nc_, nc_);
}
void swap (
array2d& item
)
{
exchange(data,item.data);
exchange(nr_,item.nr_);
exchange(nc_,item.nc_);
exchange(at_start_,item.at_start_);
exchange(cur,item.cur);
exchange(last,item.last);
pool.swap(item.pool);
}
void clear (
)
{
if (data != 0)
{
pool.deallocate_array(data);
nc_ = 0;
nr_ = 0;
data = 0;
at_start_ = true;
cur = 0;
last = 0;
}
}
void set_size (
long rows,
long cols
);
bool at_start (
) const { return at_start_; }
void reset (
) const { at_start_ = true; cur = 0; }
bool current_element_valid (
) const { return (cur != 0); }
const T& element (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_element_valid() == true,
"\tconst T& array2d::element()()"
<< "\n\tYou can only call element() when you are at a valid one."
<< "\n\tthis: " << this
);
return *cur;
}
T& element (
)
{
// make sure requires clause is not broken
DLIB_ASSERT(current_element_valid() == true,
"\tT& array2d::element()()"
<< "\n\tYou can only call element() when you are at a valid one."
<< "\n\tthis: " << this
);
return *cur;
}
bool move_next (
) const
{
if (cur != 0)
{
if (cur != last)
{
++cur;
return true;
}
cur = 0;
return false;
}
else if (at_start_)
{
cur = data;
at_start_ = false;
return (data != 0);
}
else
{
return false;
}
}
size_t size (
) const { return static_cast<size_t>(nc_) * static_cast<size_t>(nr_); }
long width_step (
) const
{
return nc_*sizeof(T);
}
iterator begin()
{
return data;
}
iterator end()
{
return data+size();
}
const_iterator begin() const
{
return data;
}
const_iterator end() const
{
return data+size();
}
private:
T* data;
long nc_;
long nr_;
typename mem_manager::template rebind<T>::other pool;
mutable T* cur;
T* last;
mutable bool at_start_;
};
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
inline void swap (
array2d<T,mem_manager>& a,
array2d<T,mem_manager>& b
) { a.swap(b); }
template <
typename T,
typename mem_manager
>
void serialize (
const array2d<T,mem_manager>& item,
std::ostream& out
)
{
try
{
// The reason the serialization is a little funny is because we are trying to
// maintain backwards compatibility with an older serialization format used by
// dlib while also encoding things in a way that lets the array2d and matrix
// objects have compatible serialization formats.
serialize(-item.nr(),out);
serialize(-item.nc(),out);
item.reset();
while (item.move_next())
serialize(item.element(),out);
item.reset();
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type array2d");
}
}
template <
typename T,
typename mem_manager
>
void deserialize (
array2d<T,mem_manager>& item,
std::istream& in
)
{
try
{
long nr, nc;
deserialize(nr,in);
deserialize(nc,in);
// this is the newer serialization format
if (nr < 0 || nc < 0)
{
nr *= -1;
nc *= -1;
}
else
{
std::swap(nr,nc);
}
item.set_size(nr,nc);
while (item.move_next())
deserialize(item.element(),in);
item.reset();
}
catch (serialization_error& e)
{
item.clear();
throw serialization_error(e.info + "\n while deserializing object of type array2d");
}
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename T,
typename mem_manager
>
void array2d<T,mem_manager>::
set_size (
long rows,
long cols
)
{
// make sure requires clause is not broken
DLIB_ASSERT((cols >= 0 && rows >= 0) ,
"\tvoid array2d::set_size(long rows, long cols)"
<< "\n\tThe array2d can't have negative rows or columns."
<< "\n\tthis: " << this
<< "\n\tcols: " << cols
<< "\n\trows: " << rows
);
// set the enumerator back at the start
at_start_ = true;
cur = 0;
// don't do anything if we are already the right size.
if (nc_ == cols && nr_ == rows)
{
return;
}
nc_ = cols;
nr_ = rows;
// free any existing memory
if (data != 0)
{
pool.deallocate_array(data);
data = 0;
}
// now setup this object to have the new size
try
{
if (nr_ > 0)
{
data = pool.allocate_array(nr_*nc_);
last = data + nr_*nc_ - 1;
}
}
catch (...)
{
if (data)
pool.deallocate_array(data);
data = 0;
nc_ = 0;
nr_ = 0;
last = 0;
throw;
}
}
// ----------------------------------------------------------------------------------------
template <typename T, typename MM>
struct is_array2d <array2d<T,MM> >
{
const static bool value = true;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ARRAY2D_KERNEl_1_
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_ARRAY2D_KERNEl_ABSTRACT_
#ifdef DLIB_ARRAY2D_KERNEl_ABSTRACT_
#include "../interfaces/enumerable.h"
#include "../serialize.h"
#include "../algs.h"
#include "../geometry/rectangle_abstract.h"
namespace dlib
{
template <
typename T,
typename mem_manager = default_memory_manager
>
class array2d : public enumerable<T>
{
/*!
REQUIREMENTS ON T
T must have a default constructor.
REQUIREMENTS ON mem_manager
must be an implementation of memory_manager/memory_manager_kernel_abstract.h or
must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or
must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h
mem_manager::type can be set to anything.
POINTERS AND REFERENCES TO INTERNAL DATA
No member functions in this object will invalidate pointers
or references to internal data except for the set_size()
and clear() member functions.
INITIAL VALUE
nr() == 0
nc() == 0
ENUMERATION ORDER
The enumerator will iterate over the elements of the array starting
with row 0 and then proceeding to row 1 and so on. Each row will be
fully enumerated before proceeding on to the next row and the elements
in a row will be enumerated beginning with the 0th column, then the 1st
column and so on.
WHAT THIS OBJECT REPRESENTS
This object represents a 2-Dimensional array of objects of
type T.
Also note that unless specified otherwise, no member functions
of this object throw exceptions.
Finally, note that this object stores its data contiguously and in
row major order. Moreover, there is no padding at the end of each row.
This means that its width_step() value is always equal to sizeof(type)*nc().
!*/
public:
// ----------------------------------------
typedef T type;
typedef mem_manager mem_manager_type;
typedef T* iterator;
typedef const T* const_iterator;
// ----------------------------------------
class row
{
/*!
POINTERS AND REFERENCES TO INTERNAL DATA
No member functions in this object will invalidate pointers
or references to internal data.
WHAT THIS OBJECT REPRESENTS
This object represents a row of Ts in an array2d object.
!*/
public:
long nc (
) const;
/*!
ensures
- returns the number of columns in this row
!*/
const T& operator[] (
long column
) const;
/*!
requires
- 0 <= column < nc()
ensures
- returns a const reference to the T in the given column
!*/
T& operator[] (
long column
);
/*!
requires
- 0 <= column < nc()
ensures
- returns a non-const reference to the T in the given column
!*/
private:
// restricted functions
row();
row& operator=(row&);
};
// ----------------------------------------
array2d (
);
/*!
ensures
- #*this is properly initialized
throws
- std::bad_alloc
!*/
array2d(const array2d&) = delete; // copy constructor
array2d& operator=(const array2d&) = delete; // assignment operator
array2d(
array2d&& item
);
/*!
ensures
- Moves the state of item into *this.
- #item is in a valid but unspecified state.
!*/
array2d (
long rows,
long cols
);
/*!
requires
- rows >= 0 && cols >= 0
ensures
- #nc() == cols
- #nr() == rows
- #at_start() == true
- all elements in this array have initial values for their type
throws
- std::bad_alloc
!*/
virtual ~array2d (
);
/*!
ensures
- all resources associated with *this has been released
!*/
void clear (
);
/*!
ensures
- #*this has an initial value for its type
!*/
long nc (
) const;
/*!
ensures
- returns the number of elements there are in a row. i.e. returns
the number of columns in *this
!*/
long nr (
) const;
/*!
ensures
- returns the number of rows in *this
!*/
void set_size (
long rows,
long cols
);
/*!
requires
- rows >= 0 && cols >= 0
ensures
- #nc() == cols
- #nr() == rows
- #at_start() == true
- if (the call to set_size() doesn't change the dimensions of this array) then
- all elements in this array retain their values from before this function was called
- else
- all elements in this array have initial values for their type
throws
- std::bad_alloc
If this exception is thrown then #*this will have an initial
value for its type.
!*/
row operator[] (
long row_index
);
/*!
requires
- 0 <= row_index < nr()
ensures
- returns a non-const row of nc() elements that represents the
given row_index'th row in *this.
!*/
const row operator[] (
long row_index
) const;
/*!
requires
- 0 <= row_index < nr()
ensures
- returns a const row of nc() elements that represents the
given row_index'th row in *this.
!*/
void swap (
array2d& item
);
/*!
ensures
- swaps *this and item
!*/
array2d& operator= (
array2d&& rhs
);
/*!
ensures
- Moves the state of item into *this.
- #item is in a valid but unspecified state.
- returns #*this
!*/
long width_step (
) const;
/*!
ensures
- returns the size of one row of the image, in bytes.
More precisely, return a number N such that:
(char*)&item[0][0] + N == (char*)&item[1][0].
- for dlib::array2d objects, the returned value
is always equal to sizeof(type)*nc(). However,
other objects which implement dlib::array2d style
interfaces might have padding at the ends of their
rows and therefore might return larger numbers.
An example of such an object is the dlib::cv_image.
!*/
iterator begin(
);
/*!
ensures
- returns a random access iterator pointing to the first element in this
object.
- The iterator will iterate over the elements of the object in row major
order.
!*/
iterator end(
);
/*!
ensures
- returns a random access iterator pointing to one past the end of the last
element in this object.
!*/
const_iterator begin(
) const;
/*!
ensures
- returns a random access iterator pointing to the first element in this
object.
- The iterator will iterate over the elements of the object in row major
order.
!*/
const_iterator end(
) const;
/*!
ensures
- returns a random access iterator pointing to one past the end of the last
element in this object.
!*/
};
template <
typename T,
typename mem_manager
>
inline void swap (
array2d<T,mem_manager>& a,
array2d<T,mem_manager>& b
) { a.swap(b); }
/*!
provides a global swap function
!*/
template <
typename T,
typename mem_manager
>
void serialize (
const array2d<T,mem_manager>& item,
std::ostream& out
);
/*!
Provides serialization support. Note that the serialization formats used by the
dlib::matrix and dlib::array2d objects are compatible. That means you can load the
serialized data from one into another and it will work properly.
!*/
template <
typename T,
typename mem_manager
>
void deserialize (
array2d<T,mem_manager>& item,
std::istream& in
);
/*!
provides deserialization support
!*/
}
#endif // DLIB_ARRAY2D_KERNEl_ABSTRACT_
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_
#define DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_
#include "array2d_kernel.h"
#include "../pixel.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
/*
This file contains overloads of the serialize functions for array2d object
for the case where they contain simple 8bit POD pixel types. In these
cases we can perform a much faster serialization by writing data in chunks
instead of one pixel at a time (this avoids a lot of function call overhead
inside the iostreams).
*/
// ----------------------------------------------------------------------------------------
template <
typename mem_manager
>
void serialize (
const array2d<rgb_pixel,mem_manager>& item,
std::ostream& out
)
{
try
{
// The reason the serialization is a little funny is because we are trying to
// maintain backwards compatibility with an older serialization format used by
// dlib while also encoding things in a way that lets the array2d and matrix
// objects have compatible serialization formats.
serialize(-item.nr(),out);
serialize(-item.nc(),out);
COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3);
if (item.size() != 0)
out.write((char*)&item[0][0], sizeof(rgb_pixel)*item.size());
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type array2d");
}
}
template <
typename mem_manager
>
void deserialize (
array2d<rgb_pixel,mem_manager>& item,
std::istream& in
)
{
try
{
COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3);
long nr, nc;
deserialize(nr,in);
deserialize(nc,in);
// this is the newer serialization format
if (nr < 0 || nc < 0)
{
nr *= -1;
nc *= -1;
}
else
{
std::swap(nr,nc);
}
item.set_size(nr,nc);
if (item.size() != 0)
in.read((char*)&item[0][0], sizeof(rgb_pixel)*item.size());
}
catch (serialization_error& e)
{
item.clear();
throw serialization_error(e.info + "\n while deserializing object of type array2d");
}
}
// ----------------------------------------------------------------------------------------
template <
typename mem_manager
>
void serialize (
const array2d<bgr_pixel,mem_manager>& item,
std::ostream& out
)
{
try
{
// The reason the serialization is a little funny is because we are trying to
// maintain backwards compatibility with an older serialization format used by
// dlib while also encoding things in a way that lets the array2d and matrix
// objects have compatible serialization formats.
serialize(-item.nr(),out);
serialize(-item.nc(),out);
COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3);
if (item.size() != 0)
out.write((char*)&item[0][0], sizeof(bgr_pixel)*item.size());
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type array2d");
}
}
template <
typename mem_manager
>
void deserialize (
array2d<bgr_pixel,mem_manager>& item,
std::istream& in
)
{
try
{
COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3);
long nr, nc;
deserialize(nr,in);
deserialize(nc,in);
// this is the newer serialization format
if (nr < 0 || nc < 0)
{
nr *= -1;
nc *= -1;
}
else
{
std::swap(nr,nc);
}
item.set_size(nr,nc);
if (item.size() != 0)
in.read((char*)&item[0][0], sizeof(bgr_pixel)*item.size());
}
catch (serialization_error& e)
{
item.clear();
throw serialization_error(e.info + "\n while deserializing object of type array2d");
}
}
// ----------------------------------------------------------------------------------------
template <
typename mem_manager
>
void serialize (
const array2d<hsi_pixel,mem_manager>& item,
std::ostream& out
)
{
try
{
// The reason the serialization is a little funny is because we are trying to
// maintain backwards compatibility with an older serialization format used by
// dlib while also encoding things in a way that lets the array2d and matrix
// objects have compatible serialization formats.
serialize(-item.nr(),out);
serialize(-item.nc(),out);
COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3);
if (item.size() != 0)
out.write((char*)&item[0][0], sizeof(hsi_pixel)*item.size());
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type array2d");
}
}
template <
typename mem_manager
>
void deserialize (
array2d<hsi_pixel,mem_manager>& item,
std::istream& in
)
{
try
{
COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3);
long nr, nc;
deserialize(nr,in);
deserialize(nc,in);
// this is the newer serialization format
if (nr < 0 || nc < 0)
{
nr *= -1;
nc *= -1;
}
else
{
std::swap(nr,nc);
}
item.set_size(nr,nc);
if (item.size() != 0)
in.read((char*)&item[0][0], sizeof(hsi_pixel)*item.size());
}
catch (serialization_error& e)
{
item.clear();
throw serialization_error(e.info + "\n while deserializing object of type array2d");
}
}
// ----------------------------------------------------------------------------------------
template <
typename mem_manager
>
void serialize (
const array2d<rgb_alpha_pixel,mem_manager>& item,
std::ostream& out
)
{
try
{
// The reason the serialization is a little funny is because we are trying to
// maintain backwards compatibility with an older serialization format used by
// dlib while also encoding things in a way that lets the array2d and matrix
// objects have compatible serialization formats.
serialize(-item.nr(),out);
serialize(-item.nc(),out);
COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4);
if (item.size() != 0)
out.write((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size());
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type array2d");
}
}
template <
typename mem_manager
>
void deserialize (
array2d<rgb_alpha_pixel,mem_manager>& item,
std::istream& in
)
{
try
{
COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4);
long nr, nc;
deserialize(nr,in);
deserialize(nc,in);
// this is the newer serialization format
if (nr < 0 || nc < 0)
{
nr *= -1;
nc *= -1;
}
else
{
std::swap(nr,nc);
}
item.set_size(nr,nc);
if (item.size() != 0)
in.read((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size());
}
catch (serialization_error& e)
{
item.clear();
throw serialization_error(e.info + "\n while deserializing object of type array2d");
}
}
// ----------------------------------------------------------------------------------------
template <
typename mem_manager
>
void serialize (
const array2d<unsigned char,mem_manager>& item,
std::ostream& out
)
{
try
{
// The reason the serialization is a little funny is because we are trying to
// maintain backwards compatibility with an older serialization format used by
// dlib while also encoding things in a way that lets the array2d and matrix
// objects have compatible serialization formats.
serialize(-item.nr(),out);
serialize(-item.nc(),out);
if (item.size() != 0)
out.write((char*)&item[0][0], sizeof(unsigned char)*item.size());
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type array2d");
}
}
template <
typename mem_manager
>
void deserialize (
array2d<unsigned char,mem_manager>& item,
std::istream& in
)
{
try
{
long nr, nc;
deserialize(nr,in);
deserialize(nc,in);
// this is the newer serialization format
if (nr < 0 || nc < 0)
{
nr *= -1;
nc *= -1;
}
else
{
std::swap(nr,nc);
}
item.set_size(nr,nc);
if (item.size() != 0)
in.read((char*)&item[0][0], sizeof(unsigned char)*item.size());
}
catch (serialization_error& e)
{
item.clear();
throw serialization_error(e.info + "\n while deserializing object of type array2d");
}
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ASSERt_
#define DLIB_ASSERt_
#include "config.h"
#include <sstream>
#include <iosfwd>
#include "error.h"
// -----------------------------
// Use some stuff from boost here
// (C) Copyright John Maddock 2001 - 2003.
// (C) Copyright Darin Adler 2001.
// (C) Copyright Peter Dimov 2001.
// (C) Copyright Bill Kempf 2002.
// (C) Copyright Jens Maurer 2002.
// (C) Copyright David Abrahams 2002 - 2003.
// (C) Copyright Gennaro Prota 2003.
// (C) Copyright Eric Friedman 2003.
// License: Boost Software License See LICENSE.txt for the full license.
//
#ifndef DLIB_BOOST_JOIN
#define DLIB_BOOST_JOIN( X, Y ) DLIB_BOOST_DO_JOIN( X, Y )
#define DLIB_BOOST_DO_JOIN( X, Y ) DLIB_BOOST_DO_JOIN2(X,Y)
#define DLIB_BOOST_DO_JOIN2( X, Y ) X##Y
#endif
// figure out if the compiler has rvalue references.
#if defined(__clang__)
# if __has_feature(cxx_rvalue_references)
# define DLIB_HAS_RVALUE_REFERENCES
# endif
# if __has_feature(cxx_generalized_initializers)
# define DLIB_HAS_INITIALIZER_LISTS
# endif
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__)
# define DLIB_HAS_RVALUE_REFERENCES
# define DLIB_HAS_INITIALIZER_LISTS
#elif defined(_MSC_VER) && _MSC_VER >= 1800
# define DLIB_HAS_INITIALIZER_LISTS
# define DLIB_HAS_RVALUE_REFERENCES
#elif defined(_MSC_VER) && _MSC_VER >= 1600
# define DLIB_HAS_RVALUE_REFERENCES
#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X)
# define DLIB_HAS_RVALUE_REFERENCES
# define DLIB_HAS_INITIALIZER_LISTS
#endif
#if defined(__APPLE__) && defined(__GNUC_LIBSTD__) && ((__GNUC_LIBSTD__-0) * 100 + __GNUC_LIBSTD_MINOR__-0 <= 402)
// Apple has not updated libstdc++ in some time and anything under 4.02 does not have <initializer_list> for sure.
# undef DLIB_HAS_INITIALIZER_LISTS
#endif
// figure out if the compiler has static_assert.
#if defined(__clang__)
# if __has_feature(cxx_static_assert)
# define DLIB_HAS_STATIC_ASSERT
# endif
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__)
# define DLIB_HAS_STATIC_ASSERT
#elif defined(_MSC_VER) && _MSC_VER >= 1600
# define DLIB_HAS_STATIC_ASSERT
#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X)
# define DLIB_HAS_STATIC_ASSERT
#endif
// -----------------------------
namespace dlib
{
template <bool value> struct compile_time_assert;
template <> struct compile_time_assert<true> { enum {value=1}; };
template <typename T, typename U> struct assert_are_same_type;
template <typename T> struct assert_are_same_type<T,T> {enum{value=1};};
template <typename T, typename U> struct assert_are_not_same_type {enum{value=1}; };
template <typename T> struct assert_are_not_same_type<T,T> {};
template <typename T, typename U> struct assert_types_match {enum{value=0};};
template <typename T> struct assert_types_match<T,T> {enum{value=1};};
}
// gcc 4.8 will warn about unused typedefs. But we use typedefs in some of the compile
// time assert macros so we need to make it not complain about them "not being used".
#ifdef __GNUC__
#define DLIB_NO_WARN_UNUSED __attribute__ ((unused))
#else
#define DLIB_NO_WARN_UNUSED
#endif
// Use the newer static_assert if it's available since it produces much more readable error
// messages.
#ifdef DLIB_HAS_STATIC_ASSERT
#define COMPILE_TIME_ASSERT(expression) static_assert(expression, "Failed assertion")
#define ASSERT_ARE_SAME_TYPE(type1, type2) static_assert(::dlib::assert_types_match<type1,type2>::value, "These types should be the same but aren't.")
#define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) static_assert(!::dlib::assert_types_match<type1,type2>::value, "These types should NOT be the same.")
#else
#define COMPILE_TIME_ASSERT(expression) \
DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_CTA, __LINE__)[::dlib::compile_time_assert<(bool)(expression)>::value]
#define ASSERT_ARE_SAME_TYPE(type1, type2) \
DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_AAST, __LINE__)[::dlib::assert_are_same_type<type1,type2>::value]
#define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) \
DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_AANST, __LINE__)[::dlib::assert_are_not_same_type<type1,type2>::value]
#endif
// -----------------------------
#if defined DLIB_DISABLE_ASSERTS
// if DLIB_DISABLE_ASSERTS is on then never enable DLIB_ASSERT no matter what.
#undef ENABLE_ASSERTS
#endif
#if !defined(DLIB_DISABLE_ASSERTS) && ( defined DEBUG || defined _DEBUG)
// make sure ENABLE_ASSERTS is defined if we are indeed using them.
#ifndef ENABLE_ASSERTS
#define ENABLE_ASSERTS
#endif
#endif
// -----------------------------
#ifdef __GNUC__
// There is a bug in version 4.4.5 of GCC on Ubuntu which causes GCC to segfault
// when __PRETTY_FUNCTION__ is used within certain templated functions. So just
// don't use it with this version of GCC.
# if !(__GNUC__ == 4 && __GNUC_MINOR__ == 4 && __GNUC_PATCHLEVEL__ == 5)
# define DLIB_FUNCTION_NAME __PRETTY_FUNCTION__
# else
# define DLIB_FUNCTION_NAME "unknown function"
# endif
#elif defined(_MSC_VER)
#define DLIB_FUNCTION_NAME __FUNCSIG__
#else
#define DLIB_FUNCTION_NAME "unknown function"
#endif
#define DLIBM_CASSERT(_exp,_message) \
{if ( !(_exp) ) \
{ \
dlib_assert_breakpoint(); \
std::ostringstream dlib_o_out; \
dlib_o_out << "\n\nError detected at line " << __LINE__ << ".\n"; \
dlib_o_out << "Error detected in file " << __FILE__ << ".\n"; \
dlib_o_out << "Error detected in function " << DLIB_FUNCTION_NAME << ".\n\n"; \
dlib_o_out << "Failing expression was " << #_exp << ".\n"; \
dlib_o_out << std::boolalpha << _message << "\n"; \
throw dlib::fatal_error(dlib::EBROKEN_ASSERT,dlib_o_out.str()); \
}}
// This macro is not needed if you have a real C++ compiler. It's here to work around bugs in Visual Studio's preprocessor.
#define DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(x) x
// Make it so the 2nd argument of DLIB_CASSERT is optional. That is, you can call it like
// DLIB_CASSERT(exp) or DLIB_CASSERT(exp,message).
#define DLIBM_CASSERT_1_ARGS(exp) DLIBM_CASSERT(exp,"")
#define DLIBM_CASSERT_2_ARGS(exp,message) DLIBM_CASSERT(exp,message)
#define DLIBM_GET_3TH_ARG(arg1, arg2, arg3, ...) arg3
#define DLIBM_CASSERT_CHOOSER(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_GET_3TH_ARG(__VA_ARGS__, DLIBM_CASSERT_2_ARGS, DLIBM_CASSERT_1_ARGS, DLIB_CASSERT_NEVER_USED))
#define DLIB_CASSERT(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_CASSERT_CHOOSER(__VA_ARGS__)(__VA_ARGS__))
#ifdef ENABLE_ASSERTS
#define DLIB_ASSERT(...) DLIB_CASSERT(__VA_ARGS__)
#define DLIB_IF_ASSERT(exp) exp
#else
#define DLIB_ASSERT(...) {}
#define DLIB_IF_ASSERT(exp)
#endif
// ----------------------------------------------------------------------------------------
/*!A DLIB_ASSERT_HAS_STANDARD_LAYOUT
This macro is meant to cause a compiler error if a type doesn't have a simple
memory layout (like a C struct). In particular, types with simple layouts are
ones which can be copied via memcpy().
This was called a POD type in C++03 and in C++0x we are looking to check if
it is a "standard layout type". Once we can use C++0x we can change this macro
to something that uses the std::is_standard_layout type_traits class.
See: http://www2.research.att.com/~bs/C++0xFAQ.html#PODs
!*/
// Use the fact that in C++03 you can't put non-PODs into a union.
#define DLIB_ASSERT_HAS_STANDARD_LAYOUT(type) \
union DLIB_BOOST_JOIN(DAHSL_,__LINE__) { type TYPE_NOT_STANDARD_LAYOUT; }; \
DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DAHSL2_,__LINE__)[sizeof(DLIB_BOOST_JOIN(DAHSL_,__LINE__))];
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// breakpoints
extern "C"
{
inline void dlib_assert_breakpoint(
) {}
/*!
ensures
- this function does nothing
It exists just so you can put breakpoints on it in a debugging tool.
It is called only when an DLIB_ASSERT or DLIB_CASSERT fails and is about to
throw an exception.
!*/
}
// -----------------------------
#include "stack_trace.h"
#endif // DLIB_ASSERt_
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BASe64_
#define DLIB_BASe64_
#include "base64/base64_kernel_1.h"
#endif // DLIB_BASe64_
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BASE64_KERNEL_1_CPp_
#define DLIB_BASE64_KERNEL_1_CPp_
#include "base64_kernel_1.h"
#include <iostream>
#include <sstream>
#include <climits>
namespace dlib
{
// ----------------------------------------------------------------------------------------
base64::line_ending_type base64::
line_ending (
) const
{
return eol_style;
}
// ----------------------------------------------------------------------------------------
void base64::
set_line_ending (
line_ending_type eol_style_
)
{
eol_style = eol_style_;
}
// ----------------------------------------------------------------------------------------
base64::
base64 (
) :
encode_table(0),
decode_table(0),
bad_value(100),
eol_style(LF)
{
try
{
encode_table = new char[64];
decode_table = new unsigned char[UCHAR_MAX];
}
catch (...)
{
if (encode_table) delete [] encode_table;
if (decode_table) delete [] decode_table;
throw;
}
// now set up the tables with the right stuff
encode_table[0] = 'A';
encode_table[17] = 'R';
encode_table[34] = 'i';
encode_table[51] = 'z';
encode_table[1] = 'B';
encode_table[18] = 'S';
encode_table[35] = 'j';
encode_table[52] = '0';
encode_table[2] = 'C';
encode_table[19] = 'T';
encode_table[36] = 'k';
encode_table[53] = '1';
encode_table[3] = 'D';
encode_table[20] = 'U';
encode_table[37] = 'l';
encode_table[54] = '2';
encode_table[4] = 'E';
encode_table[21] = 'V';
encode_table[38] = 'm';
encode_table[55] = '3';
encode_table[5] = 'F';
encode_table[22] = 'W';
encode_table[39] = 'n';
encode_table[56] = '4';
encode_table[6] = 'G';
encode_table[23] = 'X';
encode_table[40] = 'o';
encode_table[57] = '5';
encode_table[7] = 'H';
encode_table[24] = 'Y';
encode_table[41] = 'p';
encode_table[58] = '6';
encode_table[8] = 'I';
encode_table[25] = 'Z';
encode_table[42] = 'q';
encode_table[59] = '7';
encode_table[9] = 'J';
encode_table[26] = 'a';
encode_table[43] = 'r';
encode_table[60] = '8';
encode_table[10] = 'K';
encode_table[27] = 'b';
encode_table[44] = 's';
encode_table[61] = '9';
encode_table[11] = 'L';
encode_table[28] = 'c';
encode_table[45] = 't';
encode_table[62] = '+';
encode_table[12] = 'M';
encode_table[29] = 'd';
encode_table[46] = 'u';
encode_table[63] = '/';
encode_table[13] = 'N';
encode_table[30] = 'e';
encode_table[47] = 'v';
encode_table[14] = 'O';
encode_table[31] = 'f';
encode_table[48] = 'w';
encode_table[15] = 'P';
encode_table[32] = 'g';
encode_table[49] = 'x';
encode_table[16] = 'Q';
encode_table[33] = 'h';
encode_table[50] = 'y';
// we can now fill out the decode_table by using the encode_table
for (int i = 0; i < UCHAR_MAX; ++i)
{
decode_table[i] = bad_value;
}
for (unsigned char i = 0; i < 64; ++i)
{
decode_table[(unsigned char)encode_table[i]] = i;
}
}
// ----------------------------------------------------------------------------------------
base64::
~base64 (
)
{
delete [] encode_table;
delete [] decode_table;
}
// ----------------------------------------------------------------------------------------
void base64::
encode (
std::istream& in_,
std::ostream& out_
) const
{
using namespace std;
streambuf& in = *in_.rdbuf();
streambuf& out = *out_.rdbuf();
unsigned char inbuf[3];
unsigned char outbuf[4];
streamsize status = in.sgetn(reinterpret_cast<char*>(&inbuf),3);
unsigned char c1, c2, c3, c4, c5, c6;
int counter = 19;
// while we haven't hit the end of the input stream
while (status != 0)
{
if (counter == 0)
{
counter = 19;
// write a newline
char ch;
switch (eol_style)
{
case CR:
ch = '\r';
if (out.sputn(&ch,1)!=1)
throw std::ios_base::failure("error occurred in the base64 object");
break;
case LF:
ch = '\n';
if (out.sputn(&ch,1)!=1)
throw std::ios_base::failure("error occurred in the base64 object");
break;
case CRLF:
ch = '\r';
if (out.sputn(&ch,1)!=1)
throw std::ios_base::failure("error occurred in the base64 object");
ch = '\n';
if (out.sputn(&ch,1)!=1)
throw std::ios_base::failure("error occurred in the base64 object");
break;
default:
DLIB_CASSERT(false,"this should never happen");
}
}
--counter;
if (status == 3)
{
// encode the bytes in inbuf to base64 and write them to the output stream
c1 = inbuf[0]&0xfc;
c2 = inbuf[0]&0x03;
c3 = inbuf[1]&0xf0;
c4 = inbuf[1]&0x0f;
c5 = inbuf[2]&0xc0;
c6 = inbuf[2]&0x3f;
outbuf[0] = c1>>2;
outbuf[1] = (c2<<4)|(c3>>4);
outbuf[2] = (c4<<2)|(c5>>6);
outbuf[3] = c6;
outbuf[0] = encode_table[outbuf[0]];
outbuf[1] = encode_table[outbuf[1]];
outbuf[2] = encode_table[outbuf[2]];
outbuf[3] = encode_table[outbuf[3]];
// write the encoded bytes to the output stream
if (out.sputn(reinterpret_cast<char*>(&outbuf),4)!=4)
{
throw std::ios_base::failure("error occurred in the base64 object");
}
// get 3 more input bytes
status = in.sgetn(reinterpret_cast<char*>(&inbuf),3);
continue;
}
else if (status == 2)
{
// we are at the end of the input stream and need to add some padding
// encode the bytes in inbuf to base64 and write them to the output stream
c1 = inbuf[0]&0xfc;
c2 = inbuf[0]&0x03;
c3 = inbuf[1]&0xf0;
c4 = inbuf[1]&0x0f;
c5 = 0;
outbuf[0] = c1>>2;
outbuf[1] = (c2<<4)|(c3>>4);
outbuf[2] = (c4<<2)|(c5>>6);
outbuf[3] = '=';
outbuf[0] = encode_table[outbuf[0]];
outbuf[1] = encode_table[outbuf[1]];
outbuf[2] = encode_table[outbuf[2]];
// write the encoded bytes to the output stream
if (out.sputn(reinterpret_cast<char*>(&outbuf),4)!=4)
{
throw std::ios_base::failure("error occurred in the base64 object");
}
break;
}
else // in this case status must be 1
{
// we are at the end of the input stream and need to add some padding
// encode the bytes in inbuf to base64 and write them to the output stream
c1 = inbuf[0]&0xfc;
c2 = inbuf[0]&0x03;
c3 = 0;
outbuf[0] = c1>>2;
outbuf[1] = (c2<<4)|(c3>>4);
outbuf[2] = '=';
outbuf[3] = '=';
outbuf[0] = encode_table[outbuf[0]];
outbuf[1] = encode_table[outbuf[1]];
// write the encoded bytes to the output stream
if (out.sputn(reinterpret_cast<char*>(&outbuf),4)!=4)
{
throw std::ios_base::failure("error occurred in the base64 object");
}
break;
}
} // while (status != 0)
// make sure the stream buffer flushes to its I/O channel
out.pubsync();
}
// ----------------------------------------------------------------------------------------
void base64::
decode (
std::istream& in_,
std::ostream& out_
) const
{
using namespace std;
streambuf& in = *in_.rdbuf();
streambuf& out = *out_.rdbuf();
unsigned char inbuf[4];
unsigned char outbuf[3];
int inbuf_pos = 0;
streamsize status = in.sgetn(reinterpret_cast<char*>(inbuf),1);
// only count this character if it isn't some kind of filler
if (status == 1 && decode_table[inbuf[0]] != bad_value )
++inbuf_pos;
unsigned char c1, c2, c3, c4, c5, c6;
streamsize outsize;
// while we haven't hit the end of the input stream
while (status != 0)
{
// if we have 4 valid characters
if (inbuf_pos == 4)
{
inbuf_pos = 0;
// this might be the end of the encoded data so we need to figure out if
// there was any padding applied.
outsize = 3;
if (inbuf[3] == '=')
{
if (inbuf[2] == '=')
outsize = 1;
else
outsize = 2;
}
// decode the incoming characters
inbuf[0] = decode_table[inbuf[0]];
inbuf[1] = decode_table[inbuf[1]];
inbuf[2] = decode_table[inbuf[2]];
inbuf[3] = decode_table[inbuf[3]];
// now pack these guys into bytes rather than 6 bit chunks
c1 = inbuf[0]<<2;
c2 = inbuf[1]>>4;
c3 = inbuf[1]<<4;
c4 = inbuf[2]>>2;
c5 = inbuf[2]<<6;
c6 = inbuf[3];
outbuf[0] = c1|c2;
outbuf[1] = c3|c4;
outbuf[2] = c5|c6;
// write the encoded bytes to the output stream
if (out.sputn(reinterpret_cast<char*>(&outbuf),outsize)!=outsize)
{
throw std::ios_base::failure("error occurred in the base64 object");
}
}
// get more input characters
status = in.sgetn(reinterpret_cast<char*>(inbuf + inbuf_pos),1);
// only count this character if it isn't some kind of filler
if ((decode_table[inbuf[inbuf_pos]] != bad_value || inbuf[inbuf_pos] == '=') &&
status != 0)
++inbuf_pos;
} // while (status != 0)
if (inbuf_pos != 0)
{
ostringstream sout;
sout << inbuf_pos << " extra characters were found at the end of the encoded data."
<< " This may indicate that the data stream has been truncated.";
// this happens if we hit EOF in the middle of decoding a 24bit block.
throw decode_error(sout.str());
}
// make sure the stream buffer flushes to its I/O channel
out.pubsync();
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BASE64_KERNEL_1_CPp_
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BASE64_KERNEl_1_
#define DLIB_BASE64_KERNEl_1_
#include "../algs.h"
#include "base64_kernel_abstract.h"
#include <iosfwd>
namespace dlib
{
class base64
{
/*!
INITIAL VALUE
- bad_value == 100
- encode_table == a pointer to an array of 64 chars
- where x is a 6 bit value the following is true:
- encode_table[x] == the base64 encoding of x
- decode_table == a pointer to an array of UCHAR_MAX chars
- where x is any char value:
- if (x is a valid character in the base64 coding scheme) then
- decode_table[x] == the 6 bit value that x encodes
- else
- decode_table[x] == bad_value
CONVENTION
- The state of this object never changes so just refer to its
initial value.
!*/
public:
// this is here for backwards compatibility with older versions of dlib.
typedef base64 kernel_1a;
class decode_error : public dlib::error { public:
decode_error( const std::string& e) : error(e) {}};
base64 (
);
virtual ~base64 (
);
enum line_ending_type
{
CR, // i.e. "\r"
LF, // i.e. "\n"
CRLF // i.e. "\r\n"
};
line_ending_type line_ending (
) const;
void set_line_ending (
line_ending_type eol_style_
);
void encode (
std::istream& in,
std::ostream& out
) const;
void decode (
std::istream& in,
std::ostream& out
) const;
private:
char* encode_table;
unsigned char* decode_table;
const unsigned char bad_value;
line_ending_type eol_style;
// restricted functions
base64(base64&); // copy constructor
base64& operator=(base64&); // assignment operator
};
}
#ifdef NO_MAKEFILE
#include "base64_kernel_1.cpp"
#endif
#endif // DLIB_BASE64_KERNEl_1_
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_BASE64_KERNEl_ABSTRACT_
#ifdef DLIB_BASE64_KERNEl_ABSTRACT_
#include "../algs.h"
#include <iosfwd>
namespace dlib
{
class base64
{
/*!
INITIAL VALUE
- line_ending() == LF
WHAT THIS OBJECT REPRESENTS
This object consists of the two functions encode and decode.
These functions allow you to encode and decode data to and from
the Base64 Content-Transfer-Encoding defined in section 6.8 of
rfc2045.
!*/
public:
class decode_error : public dlib::error {};
base64 (
);
/*!
ensures
- #*this is properly initialized
throws
- std::bad_alloc
!*/
virtual ~base64 (
);
/*!
ensures
- all memory associated with *this has been released
!*/
enum line_ending_type
{
CR, // i.e. "\r"
LF, // i.e. "\n"
CRLF // i.e. "\r\n"
};
line_ending_type line_ending (
) const;
/*!
ensures
- returns the type of end of line bytes the encoder
will use when encoding data to base64 blocks. Note that
the ostream object you use might apply some sort of transform
to line endings as well. For example, C++ ofstream objects
usually convert '\n' into whatever a normal newline is for
your platform unless you open a file in binary mode. But
aside from file streams the ostream objects usually don't
modify the data you pass to them.
!*/
void set_line_ending (
line_ending_type eol_style
);
/*!
ensures
- #line_ending() == eol_style
!*/
void encode (
std::istream& in,
std::ostream& out
) const;
/*!
ensures
- reads all data from in (until EOF is reached) and encodes it
and writes it to out
throws
- std::ios_base::failure
if there was a problem writing to out then this exception will
be thrown.
- any other exception
this exception may be thrown if there is any other problem
!*/
void decode (
std::istream& in,
std::ostream& out
) const;
/*!
ensures
- reads data from in (until EOF is reached), decodes it,
and writes it to out.
throws
- std::ios_base::failure
if there was a problem writing to out then this exception will
be thrown.
- decode_error
if an error was detected in the encoded data that prevented
it from being correctly decoded then this exception is
thrown.
- any other exception
this exception may be thrown if there is any other problem
!*/
private:
// restricted functions
base64(base64&); // copy constructor
base64& operator=(base64&); // assignment operator
};
}
#endif // DLIB_BASE64_KERNEl_ABSTRACT_
// Copyright (C) 2007 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BAYES_UTILs_H_
#define DLIB_BAYES_UTILs_H_
#include "bayes_utils/bayes_utils.h"
#endif // DLIB_BAYES_UTILs_H_
// Copyright (C) 2007 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BAYES_UTILs_
#define DLIB_BAYES_UTILs_
#include "bayes_utils_abstract.h"
#include <algorithm>
#include <ctime>
#include <memory>
#include <vector>
#include "../string.h"
#include "../map.h"
#include "../matrix.h"
#include "../rand.h"
#include "../array.h"
#include "../set.h"
#include "../algs.h"
#include "../noncopyable.h"
#include "../graph.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class assignment
{
public:
assignment()
{
}
assignment(
const assignment& a
)
{
a.reset();
while (a.move_next())
{
unsigned long idx = a.element().key();
unsigned long value = a.element().value();
vals.add(idx,value);
}
}
assignment& operator = (
const assignment& rhs
)
{
if (this == &rhs)
return *this;
assignment(rhs).swap(*this);
return *this;
}
void clear()
{
vals.clear();
}
bool operator < (
const assignment& item
) const
{
if (size() < item.size())
return true;
else if (size() > item.size())
return false;
reset();
item.reset();
while (move_next())
{
item.move_next();
if (element().key() < item.element().key())
return true;
else if (element().key() > item.element().key())
return false;
else if (element().value() < item.element().value())
return true;
else if (element().value() > item.element().value())
return false;
}
return false;
}
bool has_index (
unsigned long idx
) const
{
return vals.is_in_domain(idx);
}
void add (
unsigned long idx,
unsigned long value = 0
)
{
// make sure requires clause is not broken
DLIB_ASSERT( has_index(idx) == false ,
"\tvoid assignment::add(idx)"
<< "\n\tYou can't add the same index to an assignment object more than once"
<< "\n\tidx: " << idx
<< "\n\tthis: " << this
);
vals.add(idx, value);
}
unsigned long& operator[] (
const long idx
)
{
// make sure requires clause is not broken
DLIB_ASSERT( has_index(idx) == true ,
"\tunsigned long assignment::operator[](idx)"
<< "\n\tYou can't access an index value if it isn't already in the object"
<< "\n\tidx: " << idx
<< "\n\tthis: " << this
);
return vals[idx];
}
const unsigned long& operator[] (
const long idx
) const
{
// make sure requires clause is not broken
DLIB_ASSERT( has_index(idx) == true ,
"\tunsigned long assignment::operator[](idx)"
<< "\n\tYou can't access an index value if it isn't already in the object"
<< "\n\tidx: " << idx
<< "\n\tthis: " << this
);
return vals[idx];
}
void swap (
assignment& item
)
{
vals.swap(item.vals);
}
void remove (
unsigned long idx
)
{
// make sure requires clause is not broken
DLIB_ASSERT( has_index(idx) == true ,
"\tunsigned long assignment::remove(idx)"
<< "\n\tYou can't remove an index value if it isn't already in the object"
<< "\n\tidx: " << idx
<< "\n\tthis: " << this
);
vals.destroy(idx);
}
unsigned long size() const { return vals.size(); }
void reset() const { vals.reset(); }
bool move_next() const { return vals.move_next(); }
map_pair<unsigned long, unsigned long>& element()
{
// make sure requires clause is not broken
DLIB_ASSERT(current_element_valid() == true,
"\tmap_pair<unsigned long,unsigned long>& assignment::element()"
<< "\n\tyou can't access the current element if it doesn't exist"
<< "\n\tthis: " << this
);
return vals.element();
}
const map_pair<unsigned long, unsigned long>& element() const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_element_valid() == true,
"\tconst map_pair<unsigned long,unsigned long>& assignment::element() const"
<< "\n\tyou can't access the current element if it doesn't exist"
<< "\n\tthis: " << this
);
return vals.element();
}
bool at_start() const { return vals.at_start(); }
bool current_element_valid() const { return vals.current_element_valid(); }
friend inline void serialize (
const assignment& item,
std::ostream& out
)
{
serialize(item.vals, out);
}
friend inline void deserialize (
assignment& item,
std::istream& in
)
{
deserialize(item.vals, in);
}
private:
mutable dlib::map<unsigned long, unsigned long>::kernel_1b_c vals;
};
inline std::ostream& operator << (
std::ostream& out,
const assignment& a
)
{
a.reset();
out << "(";
if (a.move_next())
out << a.element().key() << ":" << a.element().value();
while (a.move_next())
{
out << ", " << a.element().key() << ":" << a.element().value();
}
out << ")";
return out;
}
inline void swap (
assignment& a,
assignment& b
)
{
a.swap(b);
}
// ------------------------------------------------------------------------
class joint_probability_table
{
/*!
INITIAL VALUE
- table.size() == 0
CONVENTION
- size() == table.size()
- probability(a) == table[a]
!*/
public:
joint_probability_table (
const joint_probability_table& t
)
{
t.reset();
while (t.move_next())
{
assignment a = t.element().key();
double p = t.element().value();
set_probability(a,p);
}
}
joint_probability_table() {}
joint_probability_table& operator= (
const joint_probability_table& rhs
)
{
if (this == &rhs)
return *this;
joint_probability_table(rhs).swap(*this);
return *this;
}
void set_probability (
const assignment& a,
double p
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0.0 <= p && p <= 1.0,
"\tvoid& joint_probability_table::set_probability(a,p)"
<< "\n\tyou have given an invalid probability value"
<< "\n\tp: " << p
<< "\n\ta: " << a
<< "\n\tthis: " << this
);
if (table.is_in_domain(a))
{
table[a] = p;
}
else
{
assignment temp(a);
table.add(temp,p);
}
}
bool has_entry_for (
const assignment& a
) const
{
return table.is_in_domain(a);
}
void add_probability (
const assignment& a,
double p
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0.0 <= p && p <= 1.0,
"\tvoid& joint_probability_table::add_probability(a,p)"
<< "\n\tyou have given an invalid probability value"
<< "\n\tp: " << p
<< "\n\ta: " << a
<< "\n\tthis: " << this
);
if (table.is_in_domain(a))
{
table[a] += p;
if (table[a] > 1.0)
table[a] = 1.0;
}
else
{
assignment temp(a);
table.add(temp,p);
}
}
double probability (
const assignment& a
) const
{
return table[a];
}
void clear()
{
table.clear();
}
size_t size () const { return table.size(); }
bool move_next() const { return table.move_next(); }
void reset() const { table.reset(); }
map_pair<assignment,double>& element()
{
// make sure requires clause is not broken
DLIB_ASSERT(current_element_valid() == true,
"\tmap_pair<assignment,double>& joint_probability_table::element()"
<< "\n\tyou can't access the current element if it doesn't exist"
<< "\n\tthis: " << this
);
return table.element();
}
const map_pair<assignment,double>& element() const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_element_valid() == true,
"\tconst map_pair<assignment,double>& joint_probability_table::element() const"
<< "\n\tyou can't access the current element if it doesn't exist"
<< "\n\tthis: " << this
);
return table.element();
}
bool at_start() const { return table.at_start(); }
bool current_element_valid() const { return table.current_element_valid(); }
template <typename T>
void marginalize (
const T& vars,
joint_probability_table& out
) const
{
out.clear();
double p;
reset();
while (move_next())
{
assignment a;
const assignment& asrc = element().key();
p = element().value();
asrc.reset();
while (asrc.move_next())
{
if (vars.is_member(asrc.element().key()))
a.add(asrc.element().key(), asrc.element().value());
}
out.add_probability(a,p);
}
}
void marginalize (
const unsigned long var,
joint_probability_table& out
) const
{
out.clear();
double p;
reset();
while (move_next())
{
assignment a;
const assignment& asrc = element().key();
p = element().value();
asrc.reset();
while (asrc.move_next())
{
if (var == asrc.element().key())
a.add(asrc.element().key(), asrc.element().value());
}
out.add_probability(a,p);
}
}
void normalize (
)
{
double sum = 0;
reset();
while (move_next())
sum += element().value();
reset();
while (move_next())
element().value() /= sum;
}
void swap (
joint_probability_table& item
)
{
table.swap(item.table);
}
friend inline void serialize (
const joint_probability_table& item,
std::ostream& out
)
{
serialize(item.table, out);
}
friend inline void deserialize (
joint_probability_table& item,
std::istream& in
)
{
deserialize(item.table, in);
}
private:
dlib::map<assignment, double >::kernel_1b_c table;
};
inline void swap (
joint_probability_table& a,
joint_probability_table& b
) { a.swap(b); }
// ----------------------------------------------------------------------------------------
class conditional_probability_table : noncopyable
{
/*!
INITIAL VALUE
- table.size() == 0
CONVENTION
- if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) then
- has_entry_for(value,ps) == true
- probability(value,ps) == table[ps](value)
- else
- has_entry_for(value,ps) == false
- num_values() == num_vals
!*/
public:
conditional_probability_table()
{
clear();
}
void set_num_values (
unsigned long num
)
{
num_vals = num;
table.clear();
}
bool has_entry_for (
unsigned long value,
const assignment& ps
) const
{
if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0)
return true;
else
return false;
}
unsigned long num_values (
) const { return num_vals; }
void set_probability (
unsigned long value,
const assignment& ps,
double p
)
{
// make sure requires clause is not broken
DLIB_ASSERT( value < num_values() && 0.0 <= p && p <= 1.0 ,
"\tvoid conditional_probability_table::set_probability()"
<< "\n\tinvalid arguments to set_probability"
<< "\n\tvalue: " << value
<< "\n\tnum_values(): " << num_values()
<< "\n\tp: " << p
<< "\n\tps: " << ps
<< "\n\tthis: " << this
);
if (table.is_in_domain(ps))
{
table[ps](value) = p;
}
else
{
matrix<double,1> dist(num_vals);
set_all_elements(dist,-1);
dist(value) = p;
assignment temp(ps);
table.add(temp,dist);
}
}
double probability(
unsigned long value,
const assignment& ps
) const
{
// make sure requires clause is not broken
DLIB_ASSERT( value < num_values() && has_entry_for(value,ps) ,
"\tvoid conditional_probability_table::probability()"
<< "\n\tinvalid arguments to probability"
<< "\n\tvalue: " << value
<< "\n\tnum_values(): " << num_values()
<< "\n\tps: " << ps
<< "\n\tthis: " << this
);
return table[ps](value);
}
void clear()
{
table.clear();
num_vals = 0;
}
void empty_table ()
{
table.clear();
}
void swap (
conditional_probability_table& item
)
{
exchange(num_vals, item.num_vals);
table.swap(item.table);
}
friend inline void serialize (
const conditional_probability_table& item,
std::ostream& out
)
{
serialize(item.table, out);
serialize(item.num_vals, out);
}
friend inline void deserialize (
conditional_probability_table& item,
std::istream& in
)
{
deserialize(item.table, in);
deserialize(item.num_vals, in);
}
private:
dlib::map<assignment, matrix<double,1> >::kernel_1b_c table;
unsigned long num_vals;
};
inline void swap (
conditional_probability_table& a,
conditional_probability_table& b
) { a.swap(b); }
// ------------------------------------------------------------------------
class bayes_node : noncopyable
{
public:
bayes_node ()
{
is_instantiated = false;
value_ = 0;
}
unsigned long value (
) const { return value_;}
void set_value (
unsigned long new_value
)
{
// make sure requires clause is not broken
DLIB_ASSERT( new_value < table().num_values(),
"\tvoid bayes_node::set_value(new_value)"
<< "\n\tnew_value must be less than the number of possible values for this node"
<< "\n\tnew_value: " << new_value
<< "\n\ttable().num_values(): " << table().num_values()
<< "\n\tthis: " << this
);
value_ = new_value;
}
conditional_probability_table& table (
) { return table_; }
const conditional_probability_table& table (
) const { return table_; }
bool is_evidence (
) const { return is_instantiated; }
void set_as_nonevidence (
) { is_instantiated = false; }
void set_as_evidence (
) { is_instantiated = true; }
void swap (
bayes_node& item
)
{
exchange(value_, item.value_);
exchange(is_instantiated, item.is_instantiated);
table_.swap(item.table_);
}
friend inline void serialize (
const bayes_node& item,
std::ostream& out
)
{
serialize(item.value_, out);
serialize(item.is_instantiated, out);
serialize(item.table_, out);
}
friend inline void deserialize (
bayes_node& item,
std::istream& in
)
{
deserialize(item.value_, in);
deserialize(item.is_instantiated, in);
deserialize(item.table_, in);
}
private:
unsigned long value_;
bool is_instantiated;
conditional_probability_table table_;
};
inline void swap (
bayes_node& a,
bayes_node& b
) { a.swap(b); }
// ------------------------------------------------------------------------
namespace bayes_node_utils
{
template <typename T>
unsigned long node_num_values (
const T& bn,
unsigned long n
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes(),
"\tvoid bayes_node_utils::node_num_values(bn, n)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
);
return bn.node(n).data.table().num_values();
}
// ----------------------------------------------------------------------------------------
template <typename T>
void set_node_value (
T& bn,
unsigned long n,
unsigned long val
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes() && val < node_num_values(bn,n),
"\tvoid bayes_node_utils::set_node_value(bn, n, val)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tval: " << val
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
<< "\n\tnode_num_values(bn,n): " << node_num_values(bn,n)
);
bn.node(n).data.set_value(val);
}
// ----------------------------------------------------------------------------------------
template <typename T>
unsigned long node_value (
const T& bn,
unsigned long n
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes(),
"\tunsigned long bayes_node_utils::node_value(bn, n)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
);
return bn.node(n).data.value();
}
// ----------------------------------------------------------------------------------------
template <typename T>
bool node_is_evidence (
const T& bn,
unsigned long n
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes(),
"\tbool bayes_node_utils::node_is_evidence(bn, n)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
);
return bn.node(n).data.is_evidence();
}
// ----------------------------------------------------------------------------------------
template <typename T>
void set_node_as_evidence (
T& bn,
unsigned long n
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes(),
"\tvoid bayes_node_utils::set_node_as_evidence(bn, n)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
);
bn.node(n).data.set_as_evidence();
}
// ----------------------------------------------------------------------------------------
template <typename T>
void set_node_as_nonevidence (
T& bn,
unsigned long n
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes(),
"\tvoid bayes_node_utils::set_node_as_nonevidence(bn, n)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
);
bn.node(n).data.set_as_nonevidence();
}
// ----------------------------------------------------------------------------------------
template <typename T>
void set_node_num_values (
T& bn,
unsigned long n,
unsigned long num
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes(),
"\tvoid bayes_node_utils::set_node_num_values(bn, n, num)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
);
bn.node(n).data.table().set_num_values(num);
}
// ----------------------------------------------------------------------------------------
template <typename T>
double node_probability (
const T& bn,
unsigned long n,
unsigned long value,
const assignment& parents
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n),
"\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tvalue: " << value
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
<< "\n\tnode_num_values(bn,n): " << node_num_values(bn,n)
);
DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(),
"\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tparents.size(): " << parents.size()
<< "\n\tb.node(n).number_of_parents(): " << bn.node(n).number_of_parents()
);
#ifdef ENABLE_ASSERTS
parents.reset();
while (parents.move_next())
{
const unsigned long x = parents.element().key();
DLIB_ASSERT( bn.has_edge(x, n),
"\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tx: " << x
);
DLIB_ASSERT( parents[x] < node_num_values(bn,x),
"\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tx: " << x
<< "\n\tparents[x]: " << parents[x]
<< "\n\tnode_num_values(bn,x): " << node_num_values(bn,x)
);
}
#endif
return bn.node(n).data.table().probability(value, parents);
}
// ----------------------------------------------------------------------------------------
template <typename T>
void set_node_probability (
T& bn,
unsigned long n,
unsigned long value,
const assignment& parents,
double p
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n),
"\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tp: " << p
<< "\n\tvalue: " << value
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
<< "\n\tnode_num_values(bn,n): " << node_num_values(bn,n)
);
DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(),
"\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tp: " << p
<< "\n\tparents.size(): " << parents.size()
<< "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents()
);
DLIB_ASSERT( 0.0 <= p && p <= 1.0,
"\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tp: " << p
);
#ifdef ENABLE_ASSERTS
parents.reset();
while (parents.move_next())
{
const unsigned long x = parents.element().key();
DLIB_ASSERT( bn.has_edge(x, n),
"\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tx: " << x
);
DLIB_ASSERT( parents[x] < node_num_values(bn,x),
"\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tx: " << x
<< "\n\tparents[x]: " << parents[x]
<< "\n\tnode_num_values(bn,x): " << node_num_values(bn,x)
);
}
#endif
bn.node(n).data.table().set_probability(value,parents,p);
}
// ----------------------------------------------------------------------------------------
template <typename T>
const assignment node_first_parent_assignment (
const T& bn,
unsigned long n
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes(),
"\tconst assignment bayes_node_utils::node_first_parent_assignment(bn, n)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
);
assignment a;
const unsigned long num_parents = bn.node(n).number_of_parents();
for (unsigned long i = 0; i < num_parents; ++i)
{
a.add(bn.node(n).parent(i).index(), 0);
}
return a;
}
// ----------------------------------------------------------------------------------------
template <typename T>
bool node_next_parent_assignment (
const T& bn,
unsigned long n,
assignment& a
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes(),
"\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
);
DLIB_ASSERT( a.size() == bn.node(n).number_of_parents(),
"\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\ta.size(): " << a.size()
<< "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents()
);
#ifdef ENABLE_ASSERTS
a.reset();
while (a.move_next())
{
const unsigned long x = a.element().key();
DLIB_ASSERT( bn.has_edge(x, n),
"\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tx: " << x
);
DLIB_ASSERT( a[x] < node_num_values(bn,x),
"\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tx: " << x
<< "\n\ta[x]: " << a[x]
<< "\n\tnode_num_values(bn,x): " << node_num_values(bn,x)
);
}
#endif
// basically this loop just adds 1 to the assignment but performs
// carries if necessary
for (unsigned long p = 0; p < a.size(); ++p)
{
const unsigned long pindex = bn.node(n).parent(p).index();
a[pindex] += 1;
// if we need to perform a carry
if (a[pindex] >= node_num_values(bn,pindex))
{
a[pindex] = 0;
}
else
{
// no carry necessary so we are done
return true;
}
}
// we got through the entire loop which means a carry propagated all the way out
// so there must not be any more valid assignments left
return false;
}
// ----------------------------------------------------------------------------------------
template <typename T>
bool node_cpt_filled_out (
const T& bn,
unsigned long n
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes(),
"\tbool bayes_node_utils::node_cpt_filled_out(bn, n)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
);
const unsigned long num_values = node_num_values(bn,n);
const conditional_probability_table& table = bn.node(n).data.table();
// now loop over all the possible parent assignments for this node
assignment a(node_first_parent_assignment(bn,n));
do
{
double sum = 0;
// make sure that this assignment has an entry for all the values this node can take one
for (unsigned long value = 0; value < num_values; ++value)
{
if (table.has_entry_for(value,a) == false)
return false;
else
sum += table.probability(value,a);
}
// check if the sum of probabilities equals 1 as it should
if (std::abs(sum-1.0) > 1e-5)
return false;
} while (node_next_parent_assignment(bn,n,a));
return true;
}
}
// ----------------------------------------------------------------------------------------
class bayesian_network_gibbs_sampler : noncopyable
{
public:
bayesian_network_gibbs_sampler ()
{
rnd.set_seed(cast_to_string(std::time(0)));
}
template <
typename T
>
void sample_graph (
T& bn
)
{
using namespace bayes_node_utils;
for (unsigned long n = 0; n < bn.number_of_nodes(); ++n)
{
if (node_is_evidence(bn, n))
continue;
samples.set_size(node_num_values(bn,n));
// obtain the probability distribution for this node
for (long i = 0; i < samples.nc(); ++i)
{
set_node_value(bn, n, i);
samples(i) = node_probability(bn, n);
for (unsigned long j = 0; j < bn.node(n).number_of_children(); ++j)
samples(i) *= node_probability(bn, bn.node(n).child(j).index());
}
//normalize samples
samples /= sum(samples);
// select a random point in the probability distribution
double prob = rnd.get_random_double();
// now find the point in the distribution this probability corresponds to
long j;
for (j = 0; j < samples.nc()-1; ++j)
{
if (prob <= samples(j))
break;
else
prob -= samples(j);
}
set_node_value(bn, n, j);
}
}
private:
template <
typename T
>
double node_probability (
const T& bn,
unsigned long n
)
/*!
requires
- n < bn.number_of_nodes()
ensures
- computes the probability of node n having its current value given
the current values of its parents in the network bn
!*/
{
v.clear();
for (unsigned long i = 0; i < bn.node(n).number_of_parents(); ++i)
{
v.add(bn.node(n).parent(i).index(), bn.node(n).parent(i).data.value());
}
return bn.node(n).data.table().probability(bn.node(n).data.value(), v);
}
assignment v;
dlib::rand rnd;
matrix<double,1> samples;
};
// ----------------------------------------------------------------------------------------
namespace bayesian_network_join_tree_helpers
{
class bnjt
{
/*!
this object is the base class used in this pimpl idiom
!*/
public:
virtual ~bnjt() {}
virtual const matrix<double,1> probability(
unsigned long idx
) const = 0;
};
template <typename T, typename U>
class bnjt_impl : public bnjt
{
/*!
This object is the implementation in the pimpl idiom
!*/
public:
bnjt_impl (
const T& bn,
const U& join_tree
)
{
create_bayesian_network_join_tree(bn, join_tree, join_tree_values);
cliques.resize(bn.number_of_nodes());
// figure out which cliques contain each node
for (unsigned long i = 0; i < cliques.size(); ++i)
{
// find the smallest clique that contains node with index i
unsigned long smallest_clique = 0;
unsigned long size = std::numeric_limits<unsigned long>::max();
for (unsigned long n = 0; n < join_tree.number_of_nodes(); ++n)
{
if (join_tree.node(n).data.is_member(i) && join_tree.node(n).data.size() < size)
{
size = join_tree.node(n).data.size();
smallest_clique = n;
}
}
cliques[i] = smallest_clique;
}
}
virtual const matrix<double,1> probability(
unsigned long idx
) const
{
join_tree_values.node(cliques[idx]).data.marginalize(idx, table);
table.normalize();
var.clear();
var.add(idx);
dist.set_size(table.size());
// read the probabilities out of the table and into the row matrix
for (unsigned long i = 0; i < table.size(); ++i)
{
var[idx] = i;
dist(i) = table.probability(var);
}
return dist;
}
private:
graph< joint_probability_table, joint_probability_table >::kernel_1a_c join_tree_values;
array<unsigned long> cliques;
mutable joint_probability_table table;
mutable assignment var;
mutable matrix<double,1> dist;
// ----------------------------------------------------------------------------------------
template <typename set_type, typename node_type>
bool set_contains_all_parents_of_node (
const set_type& set,
const node_type& node
)
{
for (unsigned long i = 0; i < node.number_of_parents(); ++i)
{
if (set.is_member(node.parent(i).index()) == false)
return false;
}
return true;
}
// ----------------------------------------------------------------------------------------
template <
typename V
>
void pass_join_tree_message (
const U& join_tree,
V& bn_join_tree ,
unsigned long from,
unsigned long to
)
{
using namespace bayes_node_utils;
const typename U::edge_type& e = edge(join_tree, from, to);
typename V::edge_type& old_s = edge(bn_join_tree, from, to);
typedef typename V::edge_type joint_prob_table;
joint_prob_table new_s;
bn_join_tree.node(from).data.marginalize(e, new_s);
joint_probability_table temp(new_s);
// divide new_s by old_s and store the result in temp.
// if old_s is empty then that is the same as if it was all 1s
// so we don't have to do this if that is the case.
if (old_s.size() > 0)
{
temp.reset();
old_s.reset();
while (temp.move_next())
{
old_s.move_next();
if (old_s.element().value() != 0)
temp.element().value() /= old_s.element().value();
}
}
// now multiply temp by d and store the results in d
joint_probability_table& d = bn_join_tree.node(to).data;
d.reset();
while (d.move_next())
{
assignment a;
const assignment& asrc = d.element().key();
asrc.reset();
while (asrc.move_next())
{
if (e.is_member(asrc.element().key()))
a.add(asrc.element().key(), asrc.element().value());
}
d.element().value() *= temp.probability(a);
}
// store new_s in old_s
new_s.swap(old_s);
}
// ----------------------------------------------------------------------------------------
template <
typename V
>
void create_bayesian_network_join_tree (
const T& bn,
const U& join_tree,
V& bn_join_tree
)
/*!
requires
- bn is a proper bayesian network
- join_tree is the join tree for that bayesian network
ensures
- bn_join_tree == the output of the join tree algorithm for bayesian network inference.
So each node in this graph contains a joint_probability_table for the clique
in the corresponding node in the join_tree graph.
!*/
{
using namespace bayes_node_utils;
bn_join_tree.clear();
copy_graph_structure(join_tree, bn_join_tree);
// we need to keep track of which node is "in" each clique for the purposes of
// initializing the tables in each clique. So this vector will be used to do that
// and a value of join_tree.number_of_nodes() means that the node with
// that index is unassigned.
std::vector<unsigned long> node_assigned_to(bn.number_of_nodes(),join_tree.number_of_nodes());
// populate evidence with all the evidence node indices and their values
dlib::map<unsigned long, unsigned long>::kernel_1b_c evidence;
for (unsigned long i = 0; i < bn.number_of_nodes(); ++i)
{
if (node_is_evidence(bn, i))
{
unsigned long idx = i;
unsigned long value = node_value(bn, i);
evidence.add(idx,value);
}
}
// initialize the bn join tree
for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i)
{
bool contains_evidence = false;
std::vector<unsigned long> indices;
assignment value;
// loop over all the nodes in this clique in the join tree. In this loop
// we are making an assignment with all the values of the nodes it represents set to 0
join_tree.node(i).data.reset();
while (join_tree.node(i).data.move_next())
{
const unsigned long idx = join_tree.node(i).data.element();
indices.push_back(idx);
value.add(idx);
if (evidence.is_in_domain(join_tree.node(i).data.element()))
contains_evidence = true;
}
// now loop over all possible combinations of values that the nodes this
// clique in the join tree can take on. We do this by counting by one through all
// legal values
bool more_assignments = true;
while (more_assignments)
{
bn_join_tree.node(i).data.set_probability(value,1);
// account for any evidence
if (contains_evidence)
{
// loop over all the nodes in this cluster
for (unsigned long j = 0; j < indices.size(); ++j)
{
// if the current node is an evidence node
if (evidence.is_in_domain(indices[j]))
{
const unsigned long idx = indices[j];
const unsigned long evidence_value = evidence[idx];
if (value[idx] != evidence_value)
bn_join_tree.node(i).data.set_probability(value , 0);
}
}
}
// now check if any of the nodes in this cluster also have their parents in this cluster
join_tree.node(i).data.reset();
while (join_tree.node(i).data.move_next())
{
const unsigned long idx = join_tree.node(i).data.element();
// if this clique contains all the parents of this node and also hasn't
// been assigned to another clique
if (set_contains_all_parents_of_node(join_tree.node(i).data, bn.node(idx)) &&
(i == node_assigned_to[idx] || node_assigned_to[idx] == join_tree.number_of_nodes()) )
{
// note that this node is now assigned to this clique
node_assigned_to[idx] = i;
// node idx has all its parents in the cluster
assignment parent_values;
for (unsigned long j = 0; j < bn.node(idx).number_of_parents(); ++j)
{
const unsigned long pidx = bn.node(idx).parent(j).index();
parent_values.add(pidx, value[pidx]);
}
double temp = bn_join_tree.node(i).data.probability(value);
bn_join_tree.node(i).data.set_probability(value, temp * node_probability(bn, idx, value[idx], parent_values));
}
}
// now advance the value variable to its next possible state if there is one
more_assignments = false;
value.reset();
while (value.move_next())
{
value.element().value() += 1;
// if overflow
if (value.element().value() == node_num_values(bn, value.element().key()))
{
value.element().value() = 0;
}
else
{
more_assignments = true;
break;
}
}
} // end while (more_assignments)
}
// the tree is now initialized. Now all we need to do is perform the propagation and
// we are done
dlib::array<dlib::set<unsigned long>::compare_1b_c> remaining_msg_to_send;
dlib::array<dlib::set<unsigned long>::compare_1b_c> remaining_msg_to_receive;
remaining_msg_to_receive.resize(join_tree.number_of_nodes());
remaining_msg_to_send.resize(join_tree.number_of_nodes());
for (unsigned long i = 0; i < remaining_msg_to_receive.size(); ++i)
{
for (unsigned long j = 0; j < join_tree.node(i).number_of_neighbors(); ++j)
{
const unsigned long idx = join_tree.node(i).neighbor(j).index();
unsigned long temp;
temp = idx; remaining_msg_to_receive[i].add(temp);
temp = idx; remaining_msg_to_send[i].add(temp);
}
}
// now remaining_msg_to_receive[i] contains all the nodes that node i hasn't yet received
// a message from.
// we will consider node 0 to be the root node.
bool message_sent = true;
while (message_sent)
{
message_sent = false;
for (unsigned long i = 1; i < remaining_msg_to_send.size(); ++i)
{
// if node i hasn't sent any messages but has received all but one then send a message to the one
// node who hasn't sent i a message
if (remaining_msg_to_send[i].size() == join_tree.node(i).number_of_neighbors() && remaining_msg_to_receive[i].size() == 1)
{
unsigned long to;
// get the last remaining thing from this set
remaining_msg_to_receive[i].remove_any(to);
// send the message
pass_join_tree_message(join_tree, bn_join_tree, i, to);
// record that we sent this message
remaining_msg_to_send[i].destroy(to);
remaining_msg_to_receive[to].destroy(i);
// put to back in since we still need to receive it
remaining_msg_to_receive[i].add(to);
message_sent = true;
}
else if (remaining_msg_to_receive[i].size() == 0 && remaining_msg_to_send[i].size() > 0)
{
unsigned long to;
remaining_msg_to_send[i].remove_any(to);
remaining_msg_to_receive[to].destroy(i);
pass_join_tree_message(join_tree, bn_join_tree, i, to);
message_sent = true;
}
}
if (remaining_msg_to_receive[0].size() == 0)
{
// send a message to all of the root nodes neighbors unless we have already sent out he messages
while (remaining_msg_to_send[0].size() > 0)
{
unsigned long to;
remaining_msg_to_send[0].remove_any(to);
remaining_msg_to_receive[to].destroy(0);
pass_join_tree_message(join_tree, bn_join_tree, 0, to);
message_sent = true;
}
}
}
}
};
}
class bayesian_network_join_tree : noncopyable
{
/*!
use the pimpl idiom to push the template arguments from the class level to the
constructor level
!*/
public:
template <
typename T,
typename U
>
bayesian_network_join_tree (
const T& bn,
const U& join_tree
)
{
// make sure requires clause is not broken
DLIB_ASSERT( bn.number_of_nodes() > 0 ,
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
<< "\n\tYou have given an invalid bayesian network"
<< "\n\tthis: " << this
);
DLIB_ASSERT( is_join_tree(bn, join_tree) == true ,
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
<< "\n\tYou have given an invalid join tree for the supplied bayesian network"
<< "\n\tthis: " << this
);
DLIB_ASSERT( graph_contains_length_one_cycle(bn) == false,
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
<< "\n\tYou have given an invalid bayesian network"
<< "\n\tthis: " << this
);
DLIB_ASSERT( graph_is_connected(bn) == true,
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
<< "\n\tYou have given an invalid bayesian network"
<< "\n\tthis: " << this
);
#ifdef ENABLE_ASSERTS
for (unsigned long i = 0; i < bn.number_of_nodes(); ++i)
{
DLIB_ASSERT(bayes_node_utils::node_cpt_filled_out(bn,i) == true,
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
<< "\n\tYou have given an invalid bayesian network. "
<< "\n\tYou must finish filling out the conditional_probability_table of node " << i
<< "\n\tthis: " << this
);
}
#endif
impl.reset(new bayesian_network_join_tree_helpers::bnjt_impl<T,U>(bn, join_tree));
num_nodes = bn.number_of_nodes();
}
const matrix<double,1> probability(
unsigned long idx
) const
{
// make sure requires clause is not broken
DLIB_ASSERT( idx < number_of_nodes() ,
"\tconst matrix<double,1> bayesian_network_join_tree::probability(idx)"
<< "\n\tYou have specified an invalid node index"
<< "\n\tidx: " << idx
<< "\n\tnumber_of_nodes(): " << number_of_nodes()
<< "\n\tthis: " << this
);
return impl->probability(idx);
}
unsigned long number_of_nodes (
) const { return num_nodes; }
void swap (
bayesian_network_join_tree& item
)
{
exchange(num_nodes, item.num_nodes);
impl.swap(item.impl);
}
private:
std::unique_ptr<bayesian_network_join_tree_helpers::bnjt> impl;
unsigned long num_nodes;
};
inline void swap (
bayesian_network_join_tree& a,
bayesian_network_join_tree& b
) { a.swap(b); }
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_BAYES_UTILs_
// Copyright (C) 2007 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_BAYES_UTILs_ABSTRACT_
#ifdef DLIB_BAYES_UTILs_ABSTRACT_
#include "../algs.h"
#include "../noncopyable.h"
#include "../interfaces/enumerable.h"
#include "../interfaces/map_pair.h"
#include "../serialize.h"
#include <iostream>
namespace dlib
{
// ----------------------------------------------------------------------------------------
class assignment : public enumerable<map_pair<unsigned long, unsigned long> >
{
/*!
INITIAL VALUE
- size() == 0
ENUMERATION ORDER
The enumerator will iterate over the entries in the assignment in
ascending order according to index values. (i.e. the elements are
enumerated in sorted order according to the value of their keys)
WHAT THIS OBJECT REPRESENTS
This object models an assignment of random variables to particular values.
It is used with the joint_probability_table and conditional_probability_table
objects to represent assignments of various random variables to actual values.
So for example, if you had a joint_probability_table that represented the
following table:
P(A = 0, B = 0) = 0.2
P(A = 0, B = 1) = 0.3
P(A = 1, B = 0) = 0.1
P(A = 1, B = 1) = 0.4
Also lets define an enum so we have concrete index numbers for A and B
enum { A = 0, B = 1};
Then you could query the value of P(A=1, B=0) as follows:
assignment a;
a.set(A, 1);
a.set(B, 0);
// and now it is the case that:
table.probability(a) == 0.1
a[A] == 1
a[B] == 0
Also note that when enumerating the elements of an assignment object
the key() refers to the index and the value() refers to the value at that
index. For example:
// assume a is an assignment object
a.reset();
while (a.move_next())
{
// in this loop it is always the case that:
// a[a.element().key()] == a.element().value()
}
!*/
public:
assignment(
);
/*!
ensures
- this object is properly initialized
!*/
assignment(
const assignment& a
);
/*!
ensures
- #*this is a copy of a
!*/
assignment& operator = (
const assignment& rhs
);
/*!
ensures
- #*this is a copy of rhs
- returns *this
!*/
void clear(
);
/*!
ensures
- this object has been returned to its initial value
!*/
bool operator < (
const assignment& item
) const;
/*!
ensures
- The exact functioning of this operator is undefined. The only guarantee
is that it establishes a total ordering on all possible assignment objects.
In other words, this operator makes it so that you can use assignment
objects in the associative containers but otherwise isn't of any
particular use.
!*/
bool has_index (
unsigned long idx
) const;
/*!
ensures
- if (this assignment object has an entry for index idx) then
- returns true
- else
- returns false
!*/
void add (
unsigned long idx,
unsigned long value = 0
);
/*!
requires
- has_index(idx) == false
ensures
- #has_index(idx) == true
- #(*this)[idx] == value
!*/
void remove (
unsigned long idx
);
/*!
requires
- has_index(idx) == true
ensures
- #has_index(idx) == false
!*/
unsigned long& operator[] (
const long idx
);
/*!
requires
- has_index(idx) == true
ensures
- returns a reference to the value associated with index idx
!*/
const unsigned long& operator[] (
const long idx
) const;
/*!
requires
- has_index(idx) == true
ensures
- returns a const reference to the value associated with index idx
!*/
void swap (
assignment& item
);
/*!
ensures
- swaps *this and item
!*/
};
inline void swap (
assignment& a,
assignment& b
) { a.swap(b); }
/*!
provides a global swap
!*/
std::ostream& operator << (
std::ostream& out,
const assignment& a
);
/*!
ensures
- writes a to the given output stream in the following format:
(index1:value1, index2:value2, ..., indexN:valueN)
!*/
void serialize (
const assignment& item,
std::ostream& out
);
/*!
provides serialization support
!*/
void deserialize (
assignment& item,
std::istream& in
);
/*!
provides deserialization support
!*/
// ------------------------------------------------------------------------
class joint_probability_table : public enumerable<map_pair<assignment, double> >
{
/*!
INITIAL VALUE
- size() == 0
ENUMERATION ORDER
The enumerator will iterate over the entries in the probability table
in no particular order but they will all be visited.
WHAT THIS OBJECT REPRESENTS
This object models a joint probability table. That is, it models
the function p(X). So this object models the probability of a particular
set of variables (referred to as X).
!*/
public:
joint_probability_table(
);
/*!
ensures
- this object is properly initialized
!*/
joint_probability_table (
const joint_probability_table& t
);
/*!
ensures
- this object is a copy of t
!*/
void clear(
);
/*!
ensures
- this object has its initial value
!*/
joint_probability_table& operator= (
const joint_probability_table& rhs
);
/*!
ensures
- this object is a copy of rhs
- returns a reference to *this
!*/
bool has_entry_for (
const assignment& a
) const;
/*!
ensures
- if (this joint_probability_table has an entry for p(X = a)) then
- returns true
- else
- returns false
!*/
void set_probability (
const assignment& a,
double p
);
/*!
requires
- 0 <= p <= 1
ensures
- if (has_entry_for(a) == false) then
- #size() == size() + 1
- #probability(a) == p
- #has_entry_for(a) == true
!*/
void add_probability (
const assignment& a,
double p
);
/*!
requires
- 0 <= p <= 1
ensures
- if (has_entry_for(a) == false) then
- #size() == size() + 1
- #probability(a) == p
- else
- #probability(a) == min(probability(a) + p, 1.0)
(i.e. does a saturating add)
- #has_entry_for(a) == true
!*/
const double probability (
const assignment& a
) const;
/*!
ensures
- returns the probability p(X == a)
!*/
template <
typename T
>
void marginalize (
const T& vars,
joint_probability_table& output_table
) const;
/*!
requires
- T is an implementation of set/set_kernel_abstract.h
ensures
- marginalizes *this by summing over all variables not in vars. The
result is stored in output_table.
!*/
void marginalize (
const unsigned long var,
joint_probability_table& output_table
) const;
/*!
ensures
- is identical to calling the above marginalize() function with a set
that contains only var. Or in other words, performs a marginalization
with just one variable var. So that output_table will contain a table giving
the marginal probability of var all by itself.
!*/
void normalize (
);
/*!
ensures
- let sum == the sum of all the probabilities in this table
- after normalize() has finished it will be the case that the sum of all
the entries in this table is 1.0. This is accomplished by dividing all
the entries by the sum described above.
!*/
void swap (
joint_probability_table& item
);
/*!
ensures
- swaps *this and item
!*/
};
inline void swap (
joint_probability_table& a,
joint_probability_table& b
) { a.swap(b); }
/*!
provides a global swap
!*/
void serialize (
const joint_probability_table& item,
std::ostream& out
);
/*!
provides serialization support
!*/
void deserialize (
joint_probability_table& item,
std::istream& in
);
/*!
provides deserialization support
!*/
// ----------------------------------------------------------------------------------------
class conditional_probability_table : noncopyable
{
/*!
INITIAL VALUE
- num_values() == 0
- has_value_for(x, y) == false for all values of x and y
WHAT THIS OBJECT REPRESENTS
This object models a conditional probability table. That is, it models
the function p( X | parents). So this object models the conditional
probability of a particular variable (referred to as X) given another set
of variables (referred to as parents).
!*/
public:
conditional_probability_table(
);
/*!
ensures
- this object is properly initialized
!*/
void clear(
);
/*!
ensures
- this object has its initial value
!*/
void empty_table (
);
/*!
ensures
- for all possible v and p:
- #has_entry_for(v,p) == false
(i.e. this function clears out the table when you call it but doesn't
change the value of num_values())
!*/
void set_num_values (
unsigned long num
);
/*!
ensures
- #num_values() == num
- for all possible v and p:
- #has_entry_for(v,p) == false
(i.e. this function clears out the table when you call it)
!*/
unsigned long num_values (
) const;
/*!
ensures
- This object models the probability table p(X | parents). This
function returns the number of values X can take on.
!*/
bool has_entry_for (
unsigned long value,
const assignment& ps
) const;
/*!
ensures
- if (this conditional_probability_table has an entry for p(X = value, parents = ps)) then
- returns true
- else
- returns false
!*/
void set_probability (
unsigned long value,
const assignment& ps,
double p
);
/*!
requires
- value < num_values()
- 0 <= p <= 1
ensures
- #probability(ps, value) == p
- #has_entry_for(value, ps) == true
!*/
double probability(
unsigned long value,
const assignment& ps
) const;
/*!
requires
- value < num_values()
- has_entry_for(value, ps) == true
ensures
- returns the probability p( X = value | parents = ps).
!*/
void swap (
conditional_probability_table& item
);
/*!
ensures
- swaps *this and item
!*/
};
inline void swap (
conditional_probability_table& a,
conditional_probability_table& b
) { a.swap(b); }
/*!
provides a global swap
!*/
void serialize (
const conditional_probability_table& item,
std::ostream& out
);
/*!
provides serialization support
!*/
void deserialize (
conditional_probability_table& item,
std::istream& in
);
/*!
provides deserialization support
!*/
// ------------------------------------------------------------------------
// ------------------------------------------------------------------------
// ------------------------------------------------------------------------
class bayes_node : noncopyable
{
/*!
INITIAL VALUE
- is_evidence() == false
- value() == 0
- table().num_values() == 0
WHAT THIS OBJECT REPRESENTS
This object represents a node in a bayesian network. It is
intended to be used inside the dlib::directed_graph object to
represent bayesian networks.
!*/
public:
bayes_node (
);
/*!
ensures
- this object is properly initialized
!*/
unsigned long value (
) const;
/*!
ensures
- returns the current value of this node
!*/
void set_value (
unsigned long new_value
);
/*!
requires
- new_value < table().num_values()
ensures
- #value() == new_value
!*/
conditional_probability_table& table (
);
/*!
ensures
- returns a reference to the conditional_probability_table associated with this node
!*/
const conditional_probability_table& table (
) const;
/*!
ensures
- returns a const reference to the conditional_probability_table associated with this
node.
!*/
bool is_evidence (
) const;
/*!
ensures
- if (this is an evidence node) then
- returns true
- else
- returns false
!*/
void set_as_nonevidence (
);
/*!
ensures
- #is_evidence() == false
!*/
void set_as_evidence (
);
/*!
ensures
- #is_evidence() == true
!*/
void swap (
bayes_node& item
);
/*!
ensures
- swaps *this and item
!*/
};
inline void swap (
bayes_node& a,
bayes_node& b
) { a.swap(b); }
/*!
provides a global swap
!*/
void serialize (
const bayes_node& item,
std::ostream& out
);
/*!
provides serialization support
!*/
void deserialize (
bayes_node& item,
std::istream& in
);
/*!
provides deserialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
/*
The following group of functions are convenience functions for manipulating
bayes_node objects while they are inside a directed_graph. These functions
also have additional requires clauses that, in debug mode, will protect you
from attempts to manipulate a bayesian network in an inappropriate way.
*/
namespace bayes_node_utils
{
template <
typename T
>
void set_node_value (
T& bn,
unsigned long n,
unsigned long val
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
- val < node_num_values(bn, n)
ensures
- #bn.node(n).data.value() = val
!*/
// ------------------------------------------------------------------------------------
template <
typename T
>
unsigned long node_value (
const T& bn,
unsigned long n
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
ensures
- returns bn.node(n).data.value()
!*/
// ------------------------------------------------------------------------------------
template <
typename T
>
bool node_is_evidence (
const T& bn,
unsigned long n
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
ensures
- returns bn.node(n).data.is_evidence()
!*/
// ------------------------------------------------------------------------------------
template <
typename T
>
void set_node_as_evidence (
T& bn,
unsigned long n
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
ensures
- executes: bn.node(n).data.set_as_evidence()
!*/
// ------------------------------------------------------------------------------------
template <
typename T
>
void set_node_as_nonevidence (
T& bn,
unsigned long n
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
ensures
- executes: bn.node(n).data.set_as_nonevidence()
!*/
// ------------------------------------------------------------------------------------
template <
typename T
>
void set_node_num_values (
T& bn,
unsigned long n,
unsigned long num
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
ensures
- #bn.node(n).data.table().num_values() == num
(i.e. sets the number of different values this node can take)
!*/
// ------------------------------------------------------------------------------------
template <
typename T
>
unsigned long node_num_values (
const T& bn,
unsigned long n
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
ensures
- returns bn.node(n).data.table().num_values()
(i.e. returns the number of different values this node can take)
!*/
// ------------------------------------------------------------------------------------
template <
typename T
>
const double node_probability (
const T& bn,
unsigned long n,
unsigned long value,
const assignment& parents
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
- value < node_num_values(bn,n)
- parents.size() == bn.node(n).number_of_parents()
- if (parents.has_index(x)) then
- bn.has_edge(x, n)
- parents[x] < node_num_values(bn,x)
ensures
- returns bn.node(n).data.table().probability(value, parents)
(i.e. returns the probability of node n having the given value when
its parents have the given assignment)
!*/
// ------------------------------------------------------------------------------------
template <
typename T
>
const double set_node_probability (
const T& bn,
unsigned long n,
unsigned long value,
const assignment& parents,
double p
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
- value < node_num_values(bn,n)
- 0 <= p <= 1
- parents.size() == bn.node(n).number_of_parents()
- if (parents.has_index(x)) then
- bn.has_edge(x, n)
- parents[x] < node_num_values(bn,x)
ensures
- #bn.node(n).data.table().probability(value, parents) == p
(i.e. sets the probability of node n having the given value when
its parents have the given assignment to the probability p)
!*/
// ------------------------------------------------------------------------------------
template <typename T>
const assignment node_first_parent_assignment (
const T& bn,
unsigned long n
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
ensures
- returns an assignment A such that:
- A.size() == bn.node(n).number_of_parents()
- if (P is a parent of bn.node(n)) then
- A.has_index(P)
- A[P] == 0
- I.e. this function returns an assignment that contains all
the parents of the given node. Also, all the values of each
parent in the assignment is set to zero.
!*/
// ------------------------------------------------------------------------------------
template <typename T>
bool node_next_parent_assignment (
const T& bn,
unsigned long n,
assignment& A
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
- A.size() == bn.node(n).number_of_parents()
- if (A.has_index(x)) then
- bn.has_edge(x, n)
- A[x] < node_num_values(bn,x)
ensures
- The behavior of this function is defined by the following code:
assignment a(node_first_parent_assignment(bn,n);
do {
// this loop loops over all possible parent assignments
// of the node bn.node(n). Each time through the loop variable a
// will be the next assignment.
} while (node_next_parent_assignment(bn,n,a))
!*/
// ------------------------------------------------------------------------------------
template <typename T>
bool node_cpt_filled_out (
const T& bn,
unsigned long n
);
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
- n < bn.number_of_nodes()
ensures
- if (the conditional_probability_table bn.node(n).data.table() is
fully filled out for this node) then
- returns true
- This means that each parent assignment for the given node
along with all possible values of this node shows up in the
table.
- It also means that all the probabilities conditioned on the
same parent assignment sum to 1.0
- else
- returns false
!*/
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class bayesian_network_gibbs_sampler : noncopyable
{
/*!
INITIAL VALUE
This object has no state
WHAT THIS OBJECT REPRESENTS
This object performs Markov Chain Monte Carlo sampling of a bayesian
network using the Gibbs sampling technique.
Note that this object is limited to only bayesian networks that
don't contain deterministic nodes. That is, incorrect results may
be computed if this object is used when the bayesian network contains
any nodes that have a probability of 1 in their conditional probability
tables for any event. So don't use this object for networks with
deterministic nodes.
!*/
public:
bayesian_network_gibbs_sampler (
);
/*!
ensures
- this object is properly initialized
!*/
template <
typename T
>
void sample_graph (
T& bn
)
/*!
requires
- T is an implementation of directed_graph/directed_graph_kernel_abstract.h
- T::type == bayes_node
ensures
- modifies randomly (via the Gibbs sampling technique) samples all the nodes
in the network and updates their values with the newly sampled values
!*/
};
// ----------------------------------------------------------------------------------------
class bayesian_network_join_tree : noncopyable
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents an implementation of the join tree algorithm
for inference in bayesian networks. It doesn't have any mutable state.
To you use you just give it a directed_graph that contains a bayesian
network and a graph object that contains that networks corresponding
join tree. Then you may query this object to determine the probabilities
of any variables in the original bayesian network.
!*/
public:
template <
typename bn_type,
typename join_tree_type
>
bayesian_network_join_tree (
const bn_type& bn,
const join_tree_type& join_tree
);
/*!
requires
- bn_type is an implementation of directed_graph/directed_graph_kernel_abstract.h
- bn_type::type == bayes_node
- join_tree_type is an implementation of graph/graph_kernel_abstract.h
- join_tree_type::type is an implementation of set/set_compare_abstract.h and
this set type contains unsigned long objects.
- join_tree_type::edge_type is an implementation of set/set_compare_abstract.h and
this set type contains unsigned long objects.
- is_join_tree(bn, join_tree) == true
- bn == a valid bayesian network with all its conditional probability tables
filled out
- for all valid n:
- node_cpt_filled_out(bn,n) == true
- graph_contains_length_one_cycle(bn) == false
- graph_is_connected(bn) == true
- bn.number_of_nodes() > 0
ensures
- this object is properly initialized
!*/
unsigned long number_of_nodes (
) const;
/*!
ensures
- returns the number of nodes in the bayesian network that this
object was instantiated from.
!*/
const matrix<double,1> probability(
unsigned long idx
) const;
/*!
requires
- idx < number_of_nodes()
ensures
- returns the probability distribution for the node with index idx that was in the bayesian
network that *this was instantiated from. Let D represent this distribution, then:
- D.nc() == the number of values the node idx ranges over
- D.nr() == 1
- D(i) == the probability of node idx taking on the value i
!*/
void swap (
bayesian_network_join_tree& item
);
/*!
ensures
- swaps *this with item
!*/
};
inline void swap (
bayesian_network_join_tree& a,
bayesian_network_join_tree& b
) { a.swap(b); }
/*!
provides a global swap
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BAYES_UTILs_ABSTRACT_
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BIGINt_
#define DLIB_BIGINt_
#include "bigint/bigint_kernel_1.h"
#include "bigint/bigint_kernel_2.h"
#include "bigint/bigint_kernel_c.h"
namespace dlib
{
class bigint
{
bigint() {}
public:
//----------- kernels ---------------
// kernel_1a
typedef bigint_kernel_1
kernel_1a;
typedef bigint_kernel_c<kernel_1a>
kernel_1a_c;
// kernel_2a
typedef bigint_kernel_2
kernel_2a;
typedef bigint_kernel_c<kernel_2a>
kernel_2a_c;
};
}
#endif // DLIB_BIGINt_
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BIGINT_KERNEL_1_CPp_
#define DLIB_BIGINT_KERNEL_1_CPp_
#include "bigint_kernel_1.h"
#include <iostream>
namespace dlib
{
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// member/friend function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
bigint_kernel_1::
bigint_kernel_1 (
) :
slack(25),
data(new data_record(slack))
{}
// ----------------------------------------------------------------------------------------
bigint_kernel_1::
bigint_kernel_1 (
uint32 value
) :
slack(25),
data(new data_record(slack))
{
*(data->number) = static_cast<uint16>(value&0xFFFF);
*(data->number+1) = static_cast<uint16>((value>>16)&0xFFFF);
if (*(data->number+1) != 0)
data->digits_used = 2;
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1::
bigint_kernel_1 (
const bigint_kernel_1& item
) :
slack(25),
data(item.data)
{
data->references += 1;
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1::
~bigint_kernel_1 (
)
{
if (data->references == 1)
{
delete data;
}
else
{
data->references -= 1;
}
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 bigint_kernel_1::
operator+ (
const bigint_kernel_1& rhs
) const
{
data_record* temp = new data_record (
std::max(rhs.data->digits_used,data->digits_used) + slack
);
long_add(data,rhs.data,temp);
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator+= (
const bigint_kernel_1& rhs
)
{
// if there are other references to our data
if (data->references != 1)
{
data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack);
data->references -= 1;
long_add(data,rhs.data,temp);
data = temp;
}
// if data is not big enough for the result
else if (data->size <= std::max(data->digits_used,rhs.data->digits_used))
{
data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack);
long_add(data,rhs.data,temp);
delete data;
data = temp;
}
// there is enough size and no references
else
{
long_add(data,rhs.data,data);
}
return *this;
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 bigint_kernel_1::
operator- (
const bigint_kernel_1& rhs
) const
{
data_record* temp = new data_record (
data->digits_used + slack
);
long_sub(data,rhs.data,temp);
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator-= (
const bigint_kernel_1& rhs
)
{
// if there are other references to this data
if (data->references != 1)
{
data_record* temp = new data_record(data->digits_used+slack);
data->references -= 1;
long_sub(data,rhs.data,temp);
data = temp;
}
else
{
long_sub(data,rhs.data,data);
}
return *this;
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 bigint_kernel_1::
operator* (
const bigint_kernel_1& rhs
) const
{
data_record* temp = new data_record (
data->digits_used + rhs.data->digits_used + slack
);
long_mul(data,rhs.data,temp);
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator*= (
const bigint_kernel_1& rhs
)
{
// create a data_record to store the result of the multiplication in
data_record* temp = new data_record(rhs.data->digits_used+data->digits_used+slack);
long_mul(data,rhs.data,temp);
// if there are other references to data
if (data->references != 1)
{
data->references -= 1;
}
else
{
delete data;
}
data = temp;
return *this;
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 bigint_kernel_1::
operator/ (
const bigint_kernel_1& rhs
) const
{
data_record* temp = new data_record(data->digits_used+slack);
data_record* remainder;
try {
remainder = new data_record(data->digits_used+slack);
} catch (...) { delete temp; throw; }
long_div(data,rhs.data,temp,remainder);
delete remainder;
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator/= (
const bigint_kernel_1& rhs
)
{
data_record* temp = new data_record(data->digits_used+slack);
data_record* remainder;
try {
remainder = new data_record(data->digits_used+slack);
} catch (...) { delete temp; throw; }
long_div(data,rhs.data,temp,remainder);
// check if there are other references to data
if (data->references != 1)
{
data->references -= 1;
}
// if there are no references to data then it must be deleted
else
{
delete data;
}
data = temp;
delete remainder;
return *this;
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 bigint_kernel_1::
operator% (
const bigint_kernel_1& rhs
) const
{
data_record* temp = new data_record(data->digits_used+slack);
data_record* remainder;
try {
remainder = new data_record(data->digits_used+slack);
} catch (...) { delete temp; throw; }
long_div(data,rhs.data,temp,remainder);
delete temp;
return bigint_kernel_1(remainder,0);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator%= (
const bigint_kernel_1& rhs
)
{
data_record* temp = new data_record(data->digits_used+slack);
data_record* remainder;
try {
remainder = new data_record(data->digits_used+slack);
} catch (...) { delete temp; throw; }
long_div(data,rhs.data,temp,remainder);
// check if there are other references to data
if (data->references != 1)
{
data->references -= 1;
}
// if there are no references to data then it must be deleted
else
{
delete data;
}
data = remainder;
delete temp;
return *this;
}
// ----------------------------------------------------------------------------------------
bool bigint_kernel_1::
operator < (
const bigint_kernel_1& rhs
) const
{
return is_less_than(data,rhs.data);
}
// ----------------------------------------------------------------------------------------
bool bigint_kernel_1::
operator == (
const bigint_kernel_1& rhs
) const
{
return is_equal_to(data,rhs.data);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator= (
const bigint_kernel_1& rhs
)
{
if (this == &rhs)
return *this;
// if we have the only reference to our data then delete it
if (data->references == 1)
{
delete data;
data = rhs.data;
data->references += 1;
}
else
{
data->references -= 1;
data = rhs.data;
data->references += 1;
}
return *this;
}
// ----------------------------------------------------------------------------------------
std::ostream& operator<< (
std::ostream& out_,
const bigint_kernel_1& rhs
)
{
std::ostream out(out_.rdbuf());
typedef bigint_kernel_1 bigint;
bigint::data_record* temp = new bigint::data_record(*rhs.data,0);
// get a char array big enough to hold the number in ascii format
char* str;
try {
str = new char[(rhs.data->digits_used)*5+10];
} catch (...) { delete temp; throw; }
char* str_start = str;
str += (rhs.data->digits_used)*5+9;
*str = 0; --str;
uint16 remainder;
rhs.short_div(temp,10000,temp,remainder);
// pull the digits out of remainder
char a = remainder % 10 + '0';
remainder /= 10;
char b = remainder % 10 + '0';
remainder /= 10;
char c = remainder % 10 + '0';
remainder /= 10;
char d = remainder % 10 + '0';
remainder /= 10;
*str = a; --str;
*str = b; --str;
*str = c; --str;
*str = d; --str;
// keep looping until temp represents zero
while (temp->digits_used != 1 || *(temp->number) != 0)
{
rhs.short_div(temp,10000,temp,remainder);
// pull the digits out of remainder
char a = remainder % 10 + '0';
remainder /= 10;
char b = remainder % 10 + '0';
remainder /= 10;
char c = remainder % 10 + '0';
remainder /= 10;
char d = remainder % 10 + '0';
remainder /= 10;
*str = a; --str;
*str = b; --str;
*str = c; --str;
*str = d; --str;
}
// throw away and extra leading zeros
++str;
if (*str == '0')
++str;
if (*str == '0')
++str;
if (*str == '0')
++str;
out << str;
delete [] str_start;
delete temp;
return out_;
}
// ----------------------------------------------------------------------------------------
std::istream& operator>> (
std::istream& in_,
bigint_kernel_1& rhs
)
{
std::istream in(in_.rdbuf());
// ignore any leading whitespaces
while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n')
{
in.get();
}
// if the first digit is not an integer then this is an error
if ( !(in.peek() >= '0' && in.peek() <= '9'))
{
in_.clear(std::ios::failbit);
return in_;
}
int num_read;
bigint_kernel_1 temp;
do
{
// try to get 4 chars from in
num_read = 1;
char a = 0;
char b = 0;
char c = 0;
char d = 0;
if (in.peek() >= '0' && in.peek() <= '9')
{
num_read *= 10;
a = in.get();
}
if (in.peek() >= '0' && in.peek() <= '9')
{
num_read *= 10;
b = in.get();
}
if (in.peek() >= '0' && in.peek() <= '9')
{
num_read *= 10;
c = in.get();
}
if (in.peek() >= '0' && in.peek() <= '9')
{
num_read *= 10;
d = in.get();
}
// merge the for digits into an uint16
uint16 num = 0;
if (a != 0)
{
num = a - '0';
}
if (b != 0)
{
num *= 10;
num += b - '0';
}
if (c != 0)
{
num *= 10;
num += c - '0';
}
if (d != 0)
{
num *= 10;
num += d - '0';
}
if (num_read != 1)
{
// shift the digits in temp left by the number of new digits we just read
temp *= num_read;
// add in new digits
temp += num;
}
} while (num_read == 10000);
rhs = temp;
return in_;
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 operator+ (
uint16 lhs,
const bigint_kernel_1& rhs
)
{
typedef bigint_kernel_1 bigint;
bigint::data_record* temp = new bigint::data_record
(rhs.data->digits_used+rhs.slack);
rhs.short_add(rhs.data,lhs,temp);
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 operator+ (
const bigint_kernel_1& lhs,
uint16 rhs
)
{
typedef bigint_kernel_1 bigint;
bigint::data_record* temp = new bigint::data_record
(lhs.data->digits_used+lhs.slack);
lhs.short_add(lhs.data,rhs,temp);
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator+= (
uint16 rhs
)
{
// if there are other references to this data
if (data->references != 1)
{
data_record* temp = new data_record(data->digits_used+slack);
data->references -= 1;
short_add(data,rhs,temp);
data = temp;
}
// or if we need to enlarge data then do so
else if (data->digits_used == data->size)
{
data_record* temp = new data_record(data->digits_used+slack);
short_add(data,rhs,temp);
delete data;
data = temp;
}
// or if there is plenty of space and no references
else
{
short_add(data,rhs,data);
}
return *this;
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 operator- (
uint16 lhs,
const bigint_kernel_1& rhs
)
{
typedef bigint_kernel_1 bigint;
bigint::data_record* temp = new bigint::data_record(rhs.slack);
*(temp->number) = lhs - *(rhs.data->number);
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 operator- (
const bigint_kernel_1& lhs,
uint16 rhs
)
{
typedef bigint_kernel_1 bigint;
bigint::data_record* temp = new bigint::data_record
(lhs.data->digits_used+lhs.slack);
lhs.short_sub(lhs.data,rhs,temp);
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator-= (
uint16 rhs
)
{
// if there are other references to this data
if (data->references != 1)
{
data_record* temp = new data_record(data->digits_used+slack);
data->references -= 1;
short_sub(data,rhs,temp);
data = temp;
}
else
{
short_sub(data,rhs,data);
}
return *this;
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 operator* (
uint16 lhs,
const bigint_kernel_1& rhs
)
{
typedef bigint_kernel_1 bigint;
bigint::data_record* temp = new bigint::data_record
(rhs.data->digits_used+rhs.slack);
rhs.short_mul(rhs.data,lhs,temp);
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 operator* (
const bigint_kernel_1& lhs,
uint16 rhs
)
{
typedef bigint_kernel_1 bigint;
bigint::data_record* temp = new bigint::data_record
(lhs.data->digits_used+lhs.slack);
lhs.short_mul(lhs.data,rhs,temp);
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator*= (
uint16 rhs
)
{
// if there are other references to this data
if (data->references != 1)
{
data_record* temp = new data_record(data->digits_used+slack);
data->references -= 1;
short_mul(data,rhs,temp);
data = temp;
}
// or if we need to enlarge data
else if (data->digits_used == data->size)
{
data_record* temp = new data_record(data->digits_used+slack);
short_mul(data,rhs,temp);
delete data;
data = temp;
}
else
{
short_mul(data,rhs,data);
}
return *this;
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 operator/ (
uint16 lhs,
const bigint_kernel_1& rhs
)
{
typedef bigint_kernel_1 bigint;
bigint::data_record* temp = new bigint::data_record(rhs.slack);
// if rhs might not be bigger than lhs
if (rhs.data->digits_used == 1)
{
*(temp->number) = lhs/ *(rhs.data->number);
}
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 operator/ (
const bigint_kernel_1& lhs,
uint16 rhs
)
{
typedef bigint_kernel_1 bigint;
bigint::data_record* temp = new bigint::data_record
(lhs.data->digits_used+lhs.slack);
uint16 remainder;
lhs.short_div(lhs.data,rhs,temp,remainder);
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator/= (
uint16 rhs
)
{
uint16 remainder;
// if there are other references to this data
if (data->references != 1)
{
data_record* temp = new data_record(data->digits_used+slack);
data->references -= 1;
short_div(data,rhs,temp,remainder);
data = temp;
}
else
{
short_div(data,rhs,data,remainder);
}
return *this;
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 operator% (
uint16 lhs,
const bigint_kernel_1& rhs
)
{
typedef bigint_kernel_1 bigint;
// temp is zero by default
bigint::data_record* temp = new bigint::data_record(rhs.slack);
if (rhs.data->digits_used == 1)
{
// if rhs is just an uint16 inside then perform the modulus
*(temp->number) = lhs % *(rhs.data->number);
}
else
{
// if rhs is bigger than lhs then the answer is lhs
*(temp->number) = lhs;
}
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 operator% (
const bigint_kernel_1& lhs,
uint16 rhs
)
{
typedef bigint_kernel_1 bigint;
bigint::data_record* temp = new bigint::data_record(lhs.data->digits_used+lhs.slack);
uint16 remainder;
lhs.short_div(lhs.data,rhs,temp,remainder);
temp->digits_used = 1;
*(temp->number) = remainder;
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator%= (
uint16 rhs
)
{
uint16 remainder;
// if there are other references to this data
if (data->references != 1)
{
data_record* temp = new data_record(data->digits_used+slack);
data->references -= 1;
short_div(data,rhs,temp,remainder);
data = temp;
}
else
{
short_div(data,rhs,data,remainder);
}
data->digits_used = 1;
*(data->number) = remainder;
return *this;
}
// ----------------------------------------------------------------------------------------
bool operator < (
uint16 lhs,
const bigint_kernel_1& rhs
)
{
return (rhs.data->digits_used > 1 || lhs < *(rhs.data->number) );
}
// ----------------------------------------------------------------------------------------
bool operator < (
const bigint_kernel_1& lhs,
uint16 rhs
)
{
return (lhs.data->digits_used == 1 && *(lhs.data->number) < rhs);
}
// ----------------------------------------------------------------------------------------
bool operator == (
const bigint_kernel_1& lhs,
uint16 rhs
)
{
return (lhs.data->digits_used == 1 && *(lhs.data->number) == rhs);
}
// ----------------------------------------------------------------------------------------
bool operator == (
uint16 lhs,
const bigint_kernel_1& rhs
)
{
return (rhs.data->digits_used == 1 && *(rhs.data->number) == lhs);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator= (
uint16 rhs
)
{
// check if there are other references to our data
if (data->references != 1)
{
data->references -= 1;
try {
data = new data_record(slack);
} catch (...) { data->references += 1; throw; }
}
else
{
data->digits_used = 1;
}
*(data->number) = rhs;
return *this;
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator++ (
)
{
// if there are other references to this data then make a copy of it
if (data->references != 1)
{
data_record* temp = new data_record(data->digits_used+slack);
data->references -= 1;
increment(data,temp);
data = temp;
}
// or if we need to enlarge data then do so
else if (data->digits_used == data->size)
{
data_record* temp = new data_record(data->digits_used+slack);
increment(data,temp);
delete data;
data = temp;
}
else
{
increment(data,data);
}
return *this;
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 bigint_kernel_1::
operator++ (
int
)
{
data_record* temp; // this is the copy of temp we will return in the end
data_record* temp2 = new data_record(data->digits_used+slack);
increment(data,temp2);
temp = data;
data = temp2;
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
bigint_kernel_1& bigint_kernel_1::
operator-- (
)
{
// if there are other references to this data
if (data->references != 1)
{
data_record* temp = new data_record(data->digits_used+slack);
data->references -= 1;
decrement(data,temp);
data = temp;
}
else
{
decrement(data,data);
}
return *this;
}
// ----------------------------------------------------------------------------------------
const bigint_kernel_1 bigint_kernel_1::
operator-- (
int
)
{
data_record* temp; // this is the copy of temp we will return in the end
data_record* temp2 = new data_record(data->digits_used+slack);
decrement(data,temp2);
temp = data;
data = temp2;
return bigint_kernel_1(temp,0);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// private member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
short_add (
const data_record* data,
uint16 value,
data_record* result
) const
{
// put value into the carry part of temp
uint32 temp = value;
temp <<= 16;
const uint16* number = data->number;
const uint16* end = number + data->digits_used; // one past the end of number
uint16* r = result->number;
while (number != end)
{
// add *number and the current carry
temp = *number + (temp>>16);
// put the low word of temp into *r
*r = static_cast<uint16>(temp & 0xFFFF);
++number;
++r;
}
// if there is a final carry
if ((temp>>16) != 0)
{
result->digits_used = data->digits_used + 1;
// store the carry in the most significant digit of the result
*r = static_cast<uint16>(temp>>16);
}
else
{
result->digits_used = data->digits_used;
}
}
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
short_sub (
const data_record* data,
uint16 value,
data_record* result
) const
{
const uint16* number = data->number;
const uint16* end = number + data->digits_used - 1;
uint16* r = result->number;
uint32 temp = *number - value;
// put the low word of temp into *data
*r = static_cast<uint16>(temp & 0xFFFF);
while (number != end)
{
++number;
++r;
// subtract the carry from *number
temp = *number - (temp>>31);
// put the low word of temp into *r
*r = static_cast<uint16>(temp & 0xFFFF);
}
// if we lost a digit in the subtraction
if (*r == 0)
{
if (data->digits_used == 1)
result->digits_used = 1;
else
result->digits_used = data->digits_used - 1;
}
else
{
result->digits_used = data->digits_used;
}
}
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
short_mul (
const data_record* data,
uint16 value,
data_record* result
) const
{
uint32 temp = 0;
const uint16* number = data->number;
uint16* r = result->number;
const uint16* end = r + data->digits_used;
while ( r != end)
{
// multiply *data and value and add in the carry
temp = *number*(uint32)value + (temp>>16);
// put the low word of temp into *data
*r = static_cast<uint16>(temp & 0xFFFF);
++number;
++r;
}
// if there is a final carry
if ((temp>>16) != 0)
{
result->digits_used = data->digits_used + 1;
// put the final carry into the most significant digit of the result
*r = static_cast<uint16>(temp>>16);
}
else
{
result->digits_used = data->digits_used;
}
}
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
short_div (
const data_record* data,
uint16 value,
data_record* result,
uint16& rem
) const
{
uint16 remainder = 0;
uint32 temp;
const uint16* number = data->number + data->digits_used - 1;
const uint16* end = number - data->digits_used;
uint16* r = result->number + data->digits_used - 1;
// if we are losing a digit in this division
if (*number < value)
{
if (data->digits_used == 1)
result->digits_used = 1;
else
result->digits_used = data->digits_used - 1;
}
else
{
result->digits_used = data->digits_used;
}
// perform the actual division
while (number != end)
{
temp = *number + (((uint32)remainder)<<16);
*r = static_cast<uint16>(temp/value);
remainder = static_cast<uint16>(temp%value);
--number;
--r;
}
rem = remainder;
}
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
long_add (
const data_record* lhs,
const data_record* rhs,
data_record* result
) const
{
// put value into the carry part of temp
uint32 temp=0;
uint16* min_num; // the number with the least digits used
uint16* max_num; // the number with the most digits used
uint16* min_end; // one past the end of min_num
uint16* max_end; // one past the end of max_num
uint16* r = result->number;
uint32 max_digits_used;
if (lhs->digits_used < rhs->digits_used)
{
max_digits_used = rhs->digits_used;
min_num = lhs->number;
max_num = rhs->number;
min_end = min_num + lhs->digits_used;
max_end = max_num + rhs->digits_used;
}
else
{
max_digits_used = lhs->digits_used;
min_num = rhs->number;
max_num = lhs->number;
min_end = min_num + rhs->digits_used;
max_end = max_num + lhs->digits_used;
}
while (min_num != min_end)
{
// add *min_num, *max_num and the current carry
temp = *min_num + *max_num + (temp>>16);
// put the low word of temp into *r
*r = static_cast<uint16>(temp & 0xFFFF);
++min_num;
++max_num;
++r;
}
while (max_num != max_end)
{
// add *max_num and the current carry
temp = *max_num + (temp>>16);
// put the low word of temp into *r
*r = static_cast<uint16>(temp & 0xFFFF);
++max_num;
++r;
}
// check if there was a final carry
if ((temp>>16) != 0)
{
result->digits_used = max_digits_used + 1;
// put the carry into the most significant digit in the result
*r = static_cast<uint16>(temp>>16);
}
else
{
result->digits_used = max_digits_used;
}
}
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
long_sub (
const data_record* lhs,
const data_record* rhs,
data_record* result
) const
{
const uint16* number1 = lhs->number;
const uint16* number2 = rhs->number;
const uint16* end = number2 + rhs->digits_used;
uint16* r = result->number;
uint32 temp =0;
while (number2 != end)
{
// subtract *number2 from *number1 and then subtract any carry
temp = *number1 - *number2 - (temp>>31);
// put the low word of temp into *r
*r = static_cast<uint16>(temp & 0xFFFF);
++number1;
++number2;
++r;
}
end = lhs->number + lhs->digits_used;
while (number1 != end)
{
// subtract the carry from *number1
temp = *number1 - (temp>>31);
// put the low word of temp into *r
*r = static_cast<uint16>(temp & 0xFFFF);
++number1;
++r;
}
result->digits_used = lhs->digits_used;
// adjust the number of digits used appropriately
--r;
while (*r == 0 && result->digits_used > 1)
{
--r;
--result->digits_used;
}
}
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
long_div (
const data_record* lhs,
const data_record* rhs,
data_record* result,
data_record* remainder
) const
{
// zero result
result->digits_used = 1;
*(result->number) = 0;
uint16* a;
uint16* b;
uint16* end;
// copy lhs into remainder
remainder->digits_used = lhs->digits_used;
a = remainder->number;
end = a + remainder->digits_used;
b = lhs->number;
while (a != end)
{
*a = *b;
++a;
++b;
}
// if rhs is bigger than lhs then result == 0 and remainder == lhs
// so then we can quit right now
if (is_less_than(lhs,rhs))
{
return;
}
// make a temporary number
data_record temp(lhs->digits_used + slack);
// shift rhs left until it is one shift away from being larger than lhs and
// put the number of left shifts necessary into shifts
uint32 shifts;
shifts = (lhs->digits_used - rhs->digits_used) * 16;
shift_left(rhs,&temp,shifts);
// while (lhs > temp)
while (is_less_than(&temp,lhs))
{
shift_left(&temp,&temp,1);
++shifts;
}
// make sure lhs isn't smaller than temp
while (is_less_than(lhs,&temp))
{
shift_right(&temp,&temp);
--shifts;
}
// we want to execute the loop shifts +1 times
++shifts;
while (shifts != 0)
{
shift_left(result,result,1);
// if (temp <= remainder)
if (!is_less_than(remainder,&temp))
{
long_sub(remainder,&temp,remainder);
// increment result
uint16* r = result->number;
uint16* end = r + result->digits_used;
while (true)
{
++(*r);
// if there was no carry then we are done
if (*r != 0)
break;
++r;
// if we hit the end of r and there is still a carry then
// the next digit of r is 1 and there is one more digit used
if (r == end)
{
*r = 1;
++(result->digits_used);
break;
}
}
}
shift_right(&temp,&temp);
--shifts;
}
}
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
long_mul (
const data_record* lhs,
const data_record* rhs,
data_record* result
) const
{
// make result be zero
result->digits_used = 1;
*(result->number) = 0;
const data_record* aa;
const data_record* bb;
if (lhs->digits_used < rhs->digits_used)
{
// make copies of lhs and rhs and give them an appropriate amount of
// extra memory so there won't be any overflows
aa = lhs;
bb = rhs;
}
else
{
// make copies of lhs and rhs and give them an appropriate amount of
// extra memory so there won't be any overflows
aa = rhs;
bb = lhs;
}
// this is where we actually copy lhs and rhs
data_record b(*bb,aa->digits_used+slack); // the larger(approximately) of lhs and rhs
uint32 shift_value = 0;
uint16* anum = aa->number;
uint16* end = anum + aa->digits_used;
while (anum != end )
{
uint16 bit = 0x0001;
for (int i = 0; i < 16; ++i)
{
// if the specified bit of a is 1
if ((*anum & bit) != 0)
{
shift_left(&b,&b,shift_value);
shift_value = 0;
long_add(&b,result,result);
}
++shift_value;
bit <<= 1;
}
++anum;
}
}
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
shift_left (
const data_record* data,
data_record* result,
uint32 shift_amount
) const
{
uint32 offset = shift_amount/16;
shift_amount &= 0xf; // same as shift_amount %= 16;
uint16* r = result->number + data->digits_used + offset; // result
uint16* end = data->number;
uint16* s = end + data->digits_used; // source
const uint32 temp = 16 - shift_amount;
*r = (*(--s) >> temp);
// set the number of digits used in the result
// if the upper bits from *s were zero then don't count this first word
if (*r == 0)
{
result->digits_used = data->digits_used + offset;
}
else
{
result->digits_used = data->digits_used + offset + 1;
}
--r;
while (s != end)
{
*r = ((*s << shift_amount) | ( *(s-1) >> temp));
--r;
--s;
}
*r = *s << shift_amount;
// now zero the rest of the result
end = result->number;
while (r != end)
*(--r) = 0;
}
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
shift_right (
const data_record* data,
data_record* result
) const
{
uint16* r = result->number; // result
uint16* s = data->number; // source
uint16* end = s + data->digits_used - 1;
while (s != end)
{
*r = (*s >> 1) | (*(s+1) << 15);
++r;
++s;
}
*r = *s >> 1;
// calculate the new number for digits_used
if (*r == 0)
{
if (data->digits_used != 1)
result->digits_used = data->digits_used - 1;
else
result->digits_used = 1;
}
else
{
result->digits_used = data->digits_used;
}
}
// ----------------------------------------------------------------------------------------
bool bigint_kernel_1::
is_less_than (
const data_record* lhs,
const data_record* rhs
) const
{
uint32 lhs_digits_used = lhs->digits_used;
uint32 rhs_digits_used = rhs->digits_used;
// if lhs is definitely less than rhs
if (lhs_digits_used < rhs_digits_used )
return true;
// if lhs is definitely greater than rhs
else if (lhs_digits_used > rhs_digits_used)
return false;
else
{
uint16* end = lhs->number;
uint16* l = end + lhs_digits_used;
uint16* r = rhs->number + rhs_digits_used;
while (l != end)
{
--l;
--r;
if (*l < *r)
return true;
else if (*l > *r)
return false;
}
// at this point we know that they are equal
return false;
}
}
// ----------------------------------------------------------------------------------------
bool bigint_kernel_1::
is_equal_to (
const data_record* lhs,
const data_record* rhs
) const
{
// if lhs and rhs are definitely not equal
if (lhs->digits_used != rhs->digits_used )
{
return false;
}
else
{
uint16* l = lhs->number;
uint16* r = rhs->number;
uint16* end = l + lhs->digits_used;
while (l != end)
{
if (*l != *r)
return false;
++l;
++r;
}
// at this point we know that they are equal
return true;
}
}
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
increment (
const data_record* source,
data_record* dest
) const
{
uint16* s = source->number;
uint16* d = dest->number;
uint16* end = s + source->digits_used;
while (true)
{
*d = *s + 1;
// if there was no carry then break out of the loop
if (*d != 0)
{
dest->digits_used = source->digits_used;
// copy the rest of the digits over to d
++d; ++s;
while (s != end)
{
*d = *s;
++d;
++s;
}
break;
}
++s;
// if we have hit the end of s and there was a carry up to this point
// then just make the next digit 1 and add one to the digits used
if (s == end)
{
++d;
dest->digits_used = source->digits_used + 1;
*d = 1;
break;
}
++d;
}
}
// ----------------------------------------------------------------------------------------
void bigint_kernel_1::
decrement (
const data_record* source,
data_record* dest
) const
{
uint16* s = source->number;
uint16* d = dest->number;
uint16* end = s + source->digits_used;
while (true)
{
*d = *s - 1;
// if there was no carry then break out of the loop
if (*d != 0xFFFF)
{
// if we lost a digit in the subtraction
if (*d == 0 && s+1 == end)
{
if (source->digits_used == 1)
dest->digits_used = 1;
else
dest->digits_used = source->digits_used - 1;
}
else
{
dest->digits_used = source->digits_used;
}
break;
}
else
{
++d;
++s;
}
}
// copy the rest of the digits over to d
++d;
++s;
while (s != end)
{
*d = *s;
++d;
++s;
}
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BIGINT_KERNEL_1_CPp_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment