Commit 0886042c authored by lishen's avatar lishen
Browse files

dlib from github, version=19.24

parent 5b127120
Pipeline #262 failed with stages
in 0 seconds
// Copyright (C) 2008 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_
#define DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_
#include "../algs.h"
#include "../member_function_pointer.h"
#include "bound_function_pointer_kernel_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace bfp1_helpers
{
// Type helper used when declaring pointers to bound arguments.  For a
// non-reference type T, strip<T>::type is T itself.
template <typename T>
struct strip
{
    typedef T type;
};

// Partial specialization for lvalue references: strip<T&>::type is T.
template <typename T>
struct strip<T&>
{
    typedef T type;
};
// ------------------------------------------------------------------------------------
// Abstract interface that does not depend on any argument or callable type.
// It exposes the three operations declared below, all pure virtual.
class bound_function_helper_base_base
{
public:
    virtual ~bound_function_helper_base_base()
    {
    }

    // Invokes whatever callable the concrete helper holds.
    virtual void call () const = 0;

    // Tells whether the concrete helper currently holds a callable.
    virtual bool is_set () const = 0;

    // Duplicates the concrete helper into the raw memory at ptr.
    virtual void clone (void* ptr) const = 0;
};
// ------------------------------------------------------------------------------------
// Holds the state shared by every helper with the argument signature
// (T1,T2,T3,T4): one pointer per bound argument (references are stripped so a
// pointer can be formed) plus a member_function_pointer of the same signature.
template <typename T1, typename T2, typename T3, typename T4>
class bound_function_helper_base : public bound_function_helper_base_base
{
public:
    // All argument pointers start out null.
    bound_function_helper_base() :
        arg1(0),
        arg2(0),
        arg3(0),
        arg4(0)
    {
    }

    typename strip<T1>::type* arg1;
    typename strip<T2>::type* arg2;
    typename strip<T3>::type* arg3;
    typename strip<T4>::type* arg4;

    member_function_pointer<T1,T2,T3,T4> mfp;
};
// ----------------
// Primary template: binds a function object of type F to four arguments.
// Unused trailing argument slots default to void and are handled by the
// partial specializations.
template <
    typename F,
    typename T1 = void,
    typename T2 = void,
    typename T3 = void,
    typename T4 = void
    >
class bound_function_helper : public bound_function_helper_base<T1,T2,T3,T4>
{
public:
    // Calls the function object through fp with the four pointed-to arguments.
    void call() const
    {
        (*this->fp)(*this->arg1, *this->arg2, *this->arg3, *this->arg4);
    }

    // Pointer to the function object (reference stripped from F).
    typename strip<F>::type* fp;
};
// F == void, four arguments: the callable is either the inherited member
// function pointer (mfp) or a plain function pointer (fp).
template <typename T1, typename T2, typename T3, typename T4>
class bound_function_helper<void,T1,T2,T3,T4> : public bound_function_helper_base<T1,T2,T3,T4>
{
public:
    // Prefers mfp when it is set, otherwise falls back to fp; does nothing if
    // neither is set.
    void call() const
    {
        if (this->mfp)
        {
            this->mfp(*this->arg1, *this->arg2, *this->arg3, *this->arg4);
            return;
        }
        if (fp)
            fp(*this->arg1, *this->arg2, *this->arg3, *this->arg4);
    }

    void (*fp)(T1, T2, T3, T4);
};
// ----------------
// Function object of type F taking no arguments.
template <typename F>
class bound_function_helper<F,void,void,void,void> : public bound_function_helper_base<void,void,void,void>
{
public:
    // Calls the function object through fp.
    void call() const
    {
        (*this->fp)();
    }

    // Pointer to the function object (reference stripped from F).
    typename strip<F>::type* fp;
};
// F == void, no arguments: the callable is either the inherited member
// function pointer (mfp) or a plain function pointer (fp).
template <>
class bound_function_helper<void,void,void,void,void> : public bound_function_helper_base<void,void,void,void>
{
public:
    // Prefers mfp when it is set, otherwise falls back to fp; does nothing if
    // neither is set.
    void call() const
    {
        if (this->mfp)
        {
            this->mfp();
            return;
        }
        if (fp)
            fp();
    }

    void (*fp)();
};
// ----------------
// Function object of type F bound to one argument.
template <typename F, typename T1>
class bound_function_helper<F,T1,void,void,void> : public bound_function_helper_base<T1,void,void,void>
{
public:
    // Calls the function object through fp with the pointed-to argument.
    void call() const
    {
        (*this->fp)(*this->arg1);
    }

    // Pointer to the function object (reference stripped from F).
    typename strip<F>::type* fp;
};
// F == void, one argument: the callable is either the inherited member
// function pointer (mfp) or a plain function pointer (fp).
template <typename T1>
class bound_function_helper<void,T1,void,void,void> : public bound_function_helper_base<T1,void,void,void>
{
public:
    // Prefers mfp when it is set, otherwise falls back to fp; does nothing if
    // neither is set.
    void call() const
    {
        if (this->mfp)
        {
            this->mfp(*this->arg1);
            return;
        }
        if (fp)
            fp(*this->arg1);
    }

    void (*fp)(T1);
};
// ----------------
// Function object of type F bound to two arguments.
template <typename F, typename T1, typename T2>
class bound_function_helper<F,T1,T2,void,void> : public bound_function_helper_base<T1,T2,void,void>
{
public:
    // Calls the function object through fp with the pointed-to arguments.
    void call() const
    {
        (*this->fp)(*this->arg1, *this->arg2);
    }

    // Pointer to the function object (reference stripped from F).
    typename strip<F>::type* fp;
};
// F == void, two arguments: the callable is either the inherited member
// function pointer (mfp) or a plain function pointer (fp).
template <typename T1, typename T2>
class bound_function_helper<void,T1,T2,void,void> : public bound_function_helper_base<T1,T2,void,void>
{
public:
    // Prefers mfp when it is set, otherwise falls back to fp; does nothing if
    // neither is set.
    void call() const
    {
        if (this->mfp)
        {
            this->mfp(*this->arg1, *this->arg2);
            return;
        }
        if (fp)
            fp(*this->arg1, *this->arg2);
    }

    void (*fp)(T1, T2);
};
// ----------------
// Function object of type F bound to three arguments.
template <typename F, typename T1, typename T2, typename T3>
class bound_function_helper<F,T1,T2,T3,void> : public bound_function_helper_base<T1,T2,T3,void>
{
public:
    // Calls the function object through fp with the pointed-to arguments.
    void call() const
    {
        (*this->fp)(*this->arg1, *this->arg2, *this->arg3);
    }

    // Pointer to the function object (reference stripped from F).
    typename strip<F>::type* fp;
};
// F == void, three arguments: the callable is either the inherited member
// function pointer (mfp) or a plain function pointer (fp).
template <typename T1, typename T2, typename T3>
class bound_function_helper<void,T1,T2,T3,void> : public bound_function_helper_base<T1,T2,T3,void>
{
public:
    // Prefers mfp when it is set, otherwise falls back to fp; does nothing if
    // neither is set.
    void call() const
    {
        if (this->mfp)
        {
            this->mfp(*this->arg1, *this->arg2, *this->arg3);
            return;
        }
        if (fp)
            fp(*this->arg1, *this->arg2, *this->arg3);
    }

    void (*fp)(T1, T2, T3);
};
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// Final layer placed on top of a bound_function_helper specialization T.  It
// supplies the pieces that are identical for every specialization: null
// initialization of fp, the is_set() test, and cloning into raw memory.
template <typename T>
class bound_function_helper_T : public T
{
public:
    bound_function_helper_T()
    {
        this->fp = 0;
    }

    // True when either the function/functor pointer or the member function
    // pointer has been assigned.
    bool is_set() const
    {
        if (this->fp != 0)
            return true;
        return this->mfp.is_set();
    }

    // Clones *this into buf after checking, at compile time, that buf is large
    // enough to hold an object of this type.
    template <unsigned long mem_size>
    void safe_clone(stack_based_memory_block<mem_size>& buf)
    {
        // This is here just to validate the assumption that our block of memory we have made
        // in bf_memory is the right size to store the data for this object.  If you
        // get a compiler error on this line then email me :)
        COMPILE_TIME_ASSERT(sizeof(bound_function_helper_T) <= mem_size);
        this->clone(buf.get());
    }

    // Placement-constructs a fresh helper of this exact type at ptr and copies
    // the callable and all argument pointers into it.
    void clone (void* ptr) const
    {
        bound_function_helper_T* copy = new(ptr) bound_function_helper_T();
        copy->fp = this->fp;
        copy->mfp = this->mfp;
        copy->arg1 = this->arg1;
        copy->arg2 = this->arg2;
        copy->arg3 = this->arg3;
        copy->arg4 = this->arg4;
    }
};
}
// ----------------------------------------------------------------------------------------
class bound_function_pointer
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            A callable taking no arguments that forwards to a function object,
            member function, or plain function with up to four arguments bound
            in advance.

        CONVENTION
            - bf_memory always contains a constructed object derived from
              bfp1_helpers::bound_function_helper_base_base; bf() returns it.
            - A default constructed object holds a bf_null_type helper for which
              is_set() == false.
            - Every helper stored by set() is checked at compile time (inside
              safe_clone()) to fit in sizeof(bf_null_type) bytes.
            - Only the addresses of the function object, the target object and
              the bound arguments are stored.  No copies are made, so they all
              need to remain valid for as long as the binding may be called.
    !*/

    typedef bfp1_helpers::bound_function_helper_T<bfp1_helpers::bound_function_helper<void,int> > bf_null_type;

public:

    // These typedefs are here for backwards compatibility with previous versions of
    // dlib.
    typedef bound_function_pointer kernel_1a;
    typedef bound_function_pointer kernel_1a_c;

    // Creates an object with nothing bound (is_set() == false).
    bound_function_pointer (
    ) { bf_null_type().safe_clone(bf_memory); }

    // Copies the binding stored in item.
    bound_function_pointer (
        const bound_function_pointer& item
    ) { item.bf()->clone(bf_memory.get()); }

    ~bound_function_pointer()
    { destroy_bf_memory(); }

    // Copy-and-swap assignment; self assignment is safe because the copy is
    // made before anything in *this is touched.
    bound_function_pointer& operator= (
        const bound_function_pointer& item
    ) { bound_function_pointer(item).swap(*this); return *this; }

    // Returns this object to its default constructed state.
    void clear (
    ) { bound_function_pointer().swap(*this); }

    // True once one of the set() overloads has stored a callable.
    bool is_set (
    ) const
    {
        return bf()->is_set();
    }

    // Exchanges the bindings of *this and item.
    void swap (
        bound_function_pointer& item
    )
    {
        // Swapping an object with itself is a no-op.  Without this check the
        // code below would run the destructor of our own helper and then call
        // the virtual clone() on that destroyed helper.
        if (this == &item)
            return;

        // make a temp copy of item
        bound_function_pointer temp(item);

        // destroy the stuff in item
        item.destroy_bf_memory();
        // copy *this into item
        bf()->clone(item.bf_memory.get());

        // destroy the stuff in this
        destroy_bf_memory();
        // copy temp into *this
        temp.bf()->clone(bf_memory.get());
    }

    // Invokes the bound callable.  Requires is_set() == true.
    void operator() (
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(is_set() == true ,
            "\tvoid bound_function_pointer::operator()"
            << "\n\tYou must call set() before you can use this function"
            << "\n\tthis: " << this
        );

        bf()->call();
    }

private:
    // Support types for the conversion operator below: a pointer to member
    // function is usable in boolean contexts without being an integer type.
    struct dummy{ void nonnull() {}};
    typedef void (dummy::*safe_bool)();

public:
    // Non-null exactly when is_set() is true.
    operator safe_bool () const { return is_set() ? &dummy::nonnull : 0; }
    bool operator!() const { return !is_set(); }

    // -------------------------------------------
    //      set function object overloads
    // -------------------------------------------
    // Each overload stores the address of function_object and of every
    // argument.  Function types and pointer types are rejected at compile time
    // for F; plain functions are handled by the "set fp overloads" below.

    template <typename F>
    void set (
        F& function_object
    )
    {
        COMPILE_TIME_ASSERT(std::is_function<F>::value == false);
        COMPILE_TIME_ASSERT(std::is_pointer<F>::value == false);

        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<F> > bf_helper_type;

        bf_helper_type temp;
        temp.fp = &function_object;

        temp.safe_clone(bf_memory);
    }

    template <typename F, typename A1 >
    void set (
        F& function_object,
        A1& arg1
    )
    {
        COMPILE_TIME_ASSERT(std::is_function<F>::value == false);
        COMPILE_TIME_ASSERT(std::is_pointer<F>::value == false);

        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<F,A1> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.fp = &function_object;

        temp.safe_clone(bf_memory);
    }

    template <typename F, typename A1, typename A2 >
    void set (
        F& function_object,
        A1& arg1,
        A2& arg2
    )
    {
        COMPILE_TIME_ASSERT(std::is_function<F>::value == false);
        COMPILE_TIME_ASSERT(std::is_pointer<F>::value == false);

        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<F,A1,A2> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.fp = &function_object;

        temp.safe_clone(bf_memory);
    }

    template <typename F, typename A1, typename A2, typename A3 >
    void set (
        F& function_object,
        A1& arg1,
        A2& arg2,
        A3& arg3
    )
    {
        COMPILE_TIME_ASSERT(std::is_function<F>::value == false);
        COMPILE_TIME_ASSERT(std::is_pointer<F>::value == false);

        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<F,A1,A2,A3> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.arg3 = &arg3;
        temp.fp = &function_object;

        temp.safe_clone(bf_memory);
    }

    template <typename F, typename A1, typename A2, typename A3, typename A4>
    void set (
        F& function_object,
        A1& arg1,
        A2& arg2,
        A3& arg3,
        A4& arg4
    )
    {
        COMPILE_TIME_ASSERT(std::is_function<F>::value == false);
        COMPILE_TIME_ASSERT(std::is_pointer<F>::value == false);

        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<F,A1,A2,A3,A4> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.arg3 = &arg3;
        temp.arg4 = &arg4;
        temp.fp = &function_object;

        temp.safe_clone(bf_memory);
    }

    // -------------------------------------------
    //      set mfp overloads
    // -------------------------------------------
    // Each arity comes as a pair: one overload for non-const member functions
    // and one for const member functions.  T1..T4 are the parameter types of
    // the member function, A1..A4 the types of the objects bound to them.

    template <typename T>
    void set (
        T& object,
        void (T::*funct)()
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void> > bf_helper_type;

        bf_helper_type temp;
        temp.mfp.set(object,funct);

        temp.safe_clone(bf_memory);
    }

    template <typename T >
    void set (
        const T& object,
        void (T::*funct)()const
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void> > bf_helper_type;

        bf_helper_type temp;
        temp.mfp.set(object,funct);

        temp.safe_clone(bf_memory);
    }

    // -------------------------------------------

    template <typename T, typename T1, typename A1 >
    void set (
        T& object,
        void (T::*funct)(T1),
        A1& arg1
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.mfp.set(object,funct);

        temp.safe_clone(bf_memory);
    }

    template <typename T, typename T1, typename A1 >
    void set (
        const T& object,
        void (T::*funct)(T1)const,
        A1& arg1
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.mfp.set(object,funct);

        temp.safe_clone(bf_memory);
    }

    // ----------------

    template <typename T, typename T1, typename A1,
                          typename T2, typename A2>
    void set (
        T& object,
        void (T::*funct)(T1, T2),
        A1& arg1,
        A2& arg2
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1,T2> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.mfp.set(object,funct);

        temp.safe_clone(bf_memory);
    }

    template <typename T, typename T1, typename A1,
                          typename T2, typename A2>
    void set (
        const T& object,
        void (T::*funct)(T1, T2)const,
        A1& arg1,
        A2& arg2
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1,T2> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.mfp.set(object,funct);

        temp.safe_clone(bf_memory);
    }

    // ----------------

    template <typename T, typename T1, typename A1,
                          typename T2, typename A2,
                          typename T3, typename A3>
    void set (
        T& object,
        void (T::*funct)(T1, T2, T3),
        A1& arg1,
        A2& arg2,
        A3& arg3
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.arg3 = &arg3;
        temp.mfp.set(object,funct);

        temp.safe_clone(bf_memory);
    }

    template <typename T, typename T1, typename A1,
                          typename T2, typename A2,
                          typename T3, typename A3>
    void set (
        const T& object,
        void (T::*funct)(T1, T2, T3)const,
        A1& arg1,
        A2& arg2,
        A3& arg3
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.arg3 = &arg3;
        temp.mfp.set(object,funct);

        temp.safe_clone(bf_memory);
    }

    // ----------------

    template <typename T, typename T1, typename A1,
                          typename T2, typename A2,
                          typename T3, typename A3,
                          typename T4, typename A4>
    void set (
        T& object,
        void (T::*funct)(T1, T2, T3, T4),
        A1& arg1,
        A2& arg2,
        A3& arg3,
        A4& arg4
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3,T4> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.arg3 = &arg3;
        temp.arg4 = &arg4;
        temp.mfp.set(object,funct);

        temp.safe_clone(bf_memory);
    }

    template <typename T, typename T1, typename A1,
                          typename T2, typename A2,
                          typename T3, typename A3,
                          typename T4, typename A4>
    void set (
        const T& object,
        void (T::*funct)(T1, T2, T3, T4)const,
        A1& arg1,
        A2& arg2,
        A3& arg3,
        A4& arg4
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3,T4> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.arg3 = &arg3;
        temp.arg4 = &arg4;
        temp.mfp.set(object,funct);

        temp.safe_clone(bf_memory);
    }

    // -------------------------------------------
    //      set fp overloads
    // -------------------------------------------
    // Plain function pointers.  T1..T4 are the parameter types of the function,
    // A1..A4 the types of the objects bound to them.

    void set (
        void (*funct)()
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void> > bf_helper_type;

        bf_helper_type temp;
        temp.fp = funct;

        temp.safe_clone(bf_memory);
    }

    template <typename T1, typename A1>
    void set (
        void (*funct)(T1),
        A1& arg1
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.fp = funct;

        temp.safe_clone(bf_memory);
    }

    template <typename T1, typename A1,
              typename T2, typename A2>
    void set (
        void (*funct)(T1, T2),
        A1& arg1,
        A2& arg2
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1,T2> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.fp = funct;

        temp.safe_clone(bf_memory);
    }

    template <typename T1, typename A1,
              typename T2, typename A2,
              typename T3, typename A3>
    void set (
        void (*funct)(T1, T2, T3),
        A1& arg1,
        A2& arg2,
        A3& arg3
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.arg3 = &arg3;
        temp.fp = funct;

        temp.safe_clone(bf_memory);
    }

    template <typename T1, typename A1,
              typename T2, typename A2,
              typename T3, typename A3,
              typename T4, typename A4>
    void set (
        void (*funct)(T1, T2, T3, T4),
        A1& arg1,
        A2& arg2,
        A3& arg3,
        A4& arg4
    )
    {
        using namespace bfp1_helpers;
        destroy_bf_memory();
        typedef bound_function_helper_T<bound_function_helper<void,T1,T2,T3,T4> > bf_helper_type;

        bf_helper_type temp;
        temp.arg1 = &arg1;
        temp.arg2 = &arg2;
        temp.arg3 = &arg3;
        temp.arg4 = &arg4;
        temp.fp = funct;

        temp.safe_clone(bf_memory);
    }

    // -------------------------------------------

private:

    // In-object storage for the current helper, sized for bf_null_type.
    stack_based_memory_block<sizeof(bf_null_type)> bf_memory;

    // Runs the destructor of the helper currently living in bf_memory.  The
    // memory must be re-populated (via clone/safe_clone) before bf() is used
    // again.
    void destroy_bf_memory (
    )
    {
        // Honestly, this probably doesn't even do anything but I'm putting
        // it here just for good measure.
        bf()->~bound_function_helper_base_base();
    }

    // Access to the helper stored in bf_memory through its abstract base.
    bfp1_helpers::bound_function_helper_base_base* bf ()
    { return static_cast<bfp1_helpers::bound_function_helper_base_base*>(bf_memory.get()); }

    const bfp1_helpers::bound_function_helper_base_base* bf () const
    { return static_cast<const bfp1_helpers::bound_function_helper_base_base*>(bf_memory.get()); }

};
// ----------------------------------------------------------------------------------------
// Namespace-level swap for bound_function_pointer; forwards to the member
// swap().
inline void swap (
    bound_function_pointer& a,
    bound_function_pointer& b
)
{
    a.swap(b);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_
// Copyright (C) 2008 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_
#ifdef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_
namespace dlib
{
// ----------------------------------------------------------------------------------------
class bound_function_pointer
{
/*!
INITIAL VALUE
is_set() == false
WHAT THIS OBJECT REPRESENTS
This object represents a function with all its arguments bound to
specific objects. For example:
void test(int& var) { var = var+1; }
bound_function_pointer funct;
int a = 4;
funct.set(test,a); // bind the variable a to the first argument of the test() function
// at this point a == 4
funct();
// after funct() is called a == 5
!*/
public:
bound_function_pointer (
);
/*!
ensures
- #*this is properly initialized
!*/
bound_function_pointer(
const bound_function_pointer& item
);
/*!
ensures
- *this == item
!*/
~bound_function_pointer (
);
/*!
ensures
- any resources associated with *this have been released
!*/
bound_function_pointer& operator=(
const bound_function_pointer& item
);
/*!
ensures
- *this == item
!*/
void clear(
);
/*!
ensures
- #*this has its initial value
!*/
bool is_set (
) const;
/*!
ensures
- if (this->set() has been called) then
- returns true
- else
- returns false
!*/
operator some_undefined_pointer_type (
) const;
/*!
ensures
- if (is_set()) then
- returns a non 0 value
- else
- returns a 0 value
!*/
bool operator! (
) const;
/*!
ensures
- returns !is_set()
!*/
void operator () (
) const;
/*!
requires
- is_set() == true
ensures
- calls the bound function on the object(s) specified by the last
call to this->set()
throws
- any exception thrown by the function specified by
the previous call to this->set().
If any of these exceptions are thrown then the call to this
function will have no effect on *this.
!*/
void swap (
bound_function_pointer& item
);
/*!
ensures
- swaps *this and item
!*/
// ----------------------
template <typename F>
void set (
F& function_object
);
/*!
requires
- function_object() is a valid expression
ensures
- #is_set() == true
- calls to this->operator() will call function_object()
(This seems pointless but it is a useful base case)
!*/
template < typename T>
void set (
T& object,
void (T::*funct)()
);
/*!
requires
- funct == a valid member function pointer for class T
ensures
- #is_set() == true
- calls to this->operator() will call (object.*funct)()
!*/
template < typename T>
void set (
const T& object,
void (T::*funct)()const
);
/*!
requires
- funct == a valid member function pointer for class T
ensures
- #is_set() == true
- calls to this->operator() will call (object.*funct)()
!*/
void set (
void (*funct)()
);
/*!
requires
- funct == a valid function pointer
ensures
- #is_set() == true
- calls to this->operator() will call funct()
!*/
// ----------------------
template <typename F, typename A1 >
void set (
F& function_object,
A1& arg1
);
/*!
requires
- function_object(arg1) is a valid expression
ensures
- #is_set() == true
- calls to this->operator() will call function_object(arg1)
!*/
template < typename T, typename T1, typename A1 >
void set (
T& object,
void (T::*funct)(T1),
A1& arg1
);
/*!
requires
- funct == a valid member function pointer for class T
ensures
- #is_set() == true
- calls to this->operator() will call (object.*funct)(arg1)
!*/
template < typename T, typename T1, typename A1 >
void set (
const T& object,
void (T::*funct)(T1)const,
A1& arg1
);
/*!
requires
- funct == a valid member function pointer for class T
ensures
- #is_set() == true
- calls to this->operator() will call (object.*funct)(arg1)
!*/
template <typename T1, typename A1>
void set (
void (*funct)(T1),
A1& arg1
);
/*!
requires
- funct == a valid function pointer
ensures
- #is_set() == true
- calls to this->operator() will call funct(arg1)
!*/
// ----------------------
template <typename F, typename A1, typename A2 >
void set (
F& function_object,
A1& arg1,
A2& arg2
);
/*!
requires
- function_object(arg1,arg2) is a valid expression
ensures
- #is_set() == true
- calls to this->operator() will call function_object(arg1,arg2)
!*/
template < typename T, typename T1, typename A1,
typename T2, typename A2>
void set (
T& object,
void (T::*funct)(T1,T2),
A1& arg1,
A2& arg2
);
/*!
requires
- funct == a valid member function pointer for class T
ensures
- #is_set() == true
- calls to this->operator() will call (object.*funct)(arg1,arg2)
!*/
template < typename T, typename T1, typename A1,
typename T2, typename A2>
void set (
const T& object,
void (T::*funct)(T1,T2)const,
A1& arg1,
A2& arg2
);
/*!
requires
- funct == a valid member function pointer for class T
ensures
- #is_set() == true
- calls to this->operator() will call (object.*funct)(arg1,arg2)
!*/
template <typename T1, typename A1,
typename T2, typename A2>
void set (
void (*funct)(T1,T2),
A1& arg1,
A2& arg2
);
/*!
requires
- funct == a valid function pointer
ensures
- #is_set() == true
- calls to this->operator() will call funct(arg1,arg2)
!*/
// ----------------------
template <typename F, typename A1, typename A2, typename A3 >
void set (
F& function_object,
A1& arg1,
A2& arg2,
A3& arg3
);
/*!
requires
- function_object(arg1,arg2,arg3) is a valid expression
ensures
- #is_set() == true
- calls to this->operator() will call function_object(arg1,arg2,arg3)
!*/
template < typename T, typename T1, typename A1,
typename T2, typename A2,
typename T3, typename A3>
void set (
T& object,
void (T::*funct)(T1,T2,T3),
A1& arg1,
A2& arg2,
A3& arg3
);
/*!
requires
- funct == a valid member function pointer for class T
ensures
- #is_set() == true
- calls to this->operator() will call (object.*funct)(arg1,arg2,arg3)
!*/
template < typename T, typename T1, typename A1,
typename T2, typename A2,
typename T3, typename A3>
void set (
const T& object,
void (T::*funct)(T1,T2,T3)const,
A1& arg1,
A2& arg2,
A3& arg3
);
/*!
requires
- funct == a valid member function pointer for class T
ensures
- #is_set() == true
- calls to this->operator() will call (object.*funct)(arg1,arg2,arg3)
!*/
template <typename T1, typename A1,
typename T2, typename A2,
typename T3, typename A3>
void set (
void (*funct)(T1,T2,T3),
A1& arg1,
A2& arg2,
A3& arg3
);
/*!
requires
- funct == a valid function pointer
ensures
- #is_set() == true
- calls to this->operator() will call funct(arg1,arg2,arg3)
!*/
// ----------------------
template <typename F, typename A1, typename A2, typename A3, typename A4>
void set (
F& function_object,
A1& arg1,
A2& arg2,
A3& arg3,
A4& arg4
);
/*!
requires
- function_object(arg1,arg2,arg3,arg4) is a valid expression
ensures
- #is_set() == true
- calls to this->operator() will call function_object(arg1,arg2,arg3,arg4)
!*/
template < typename T, typename T1, typename A1,
typename T2, typename A2,
typename T3, typename A3,
typename T4, typename A4>
void set (
T& object,
void (T::*funct)(T1,T2,T3,T4),
A1& arg1,
A2& arg2,
A3& arg3,
A4& arg4
);
/*!
requires
- funct == a valid member function pointer for class T
ensures
- #is_set() == true
- calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4)
!*/
template < typename T, typename T1, typename A1,
typename T2, typename A2,
typename T3, typename A3,
typename T4, typename A4>
void set (
const T& object,
void (T::*funct)(T1,T2,T3,T4)const,
A1& arg1,
A2& arg2,
A3& arg3,
A4& arg4
);
/*!
requires
- funct == a valid member function pointer for class T
ensures
- #is_set() == true
- calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4)
!*/
template <typename T1, typename A1,
typename T2, typename A2,
typename T3, typename A3,
typename T4, typename A4>
void set (
void (*funct)(T1,T2,T3,T4),
A1& arg1,
A2& arg2,
A3& arg3,
A4& arg4
);
/*!
requires
- funct == a valid function pointer
ensures
- #is_set() == true
- calls to this->operator() will call funct(arg1,arg2,arg3,arg4)
!*/
};
// ----------------------------------------------------------------------------------------
inline void swap (
bound_function_pointer& a,
bound_function_pointer& b
) { a.swap(b); }
/*!
provides a global swap function
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifdef DLIB_ALL_SOURCE_END
#include "dlib_basic_cpp_build_tutorial.txt"
#endif
#ifndef DLIB_BRIdGE_
#define DLIB_BRIdGE_
#include "bridge/bridge.h"
#endif // DLIB_BRIdGE_
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BRIDGe_Hh_
#define DLIB_BRIDGe_Hh_
#include <iostream>
#include <memory>
#include <string>
#include "bridge_abstract.h"
#include "../pipe.h"
#include "../threads.h"
#include "../serialize.h"
#include "../sockets.h"
#include "../sockstreambuf.h"
#include "../logger.h"
#include "../algs.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
struct connect_to_ip_and_port
{
    /*!
        Simple value type carrying an IP address string and a TCP port.  The
        members are private; only class bridge (a friend) reads them.
    !*/

    /*!
        requires
            - is_ip_address(ip_) == true
            - port_ != 0
    !*/
    connect_to_ip_and_port (
        const std::string& ip_,
        unsigned short port_
    ) :
        ip(ip_),
        port(port_)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(is_ip_address(ip) && port != 0,
            "\t connect_to_ip_and_port()"
            << "\n\t Invalid inputs were given to this function"
            << "\n\t ip: " << ip
            << "\n\t port: " << port
            << "\n\t this: " << this
            );
    }

private:
    friend class bridge;

    const std::string ip;
    const unsigned short port;
};
/*!
    requires
        - addr.port != 0
    ensures
        - if addr.host_address is already an IP address then it is used as is,
          otherwise it is passed to hostname_to_ip() to obtain one.
        - returns a connect_to_ip_and_port built from that IP and addr.port
    throws
        - socket_error (with code ERESOLVE) if hostname_to_ip() reports a
          nonzero result.
!*/
inline connect_to_ip_and_port connect_to (
    const network_address& addr
)
{
    // make sure requires clause is not broken
    DLIB_ASSERT(addr.port != 0,
        "\t connect_to_ip_and_port()"
        << "\n\t The TCP port to connect to can't be 0."
        << "\n\t addr.port: " << addr.port
        );

    // Nothing to resolve when the host is given as a literal IP address.
    if (is_ip_address(addr.host_address))
        return connect_to_ip_and_port(addr.host_address, addr.port);

    std::string resolved_ip;
    if (hostname_to_ip(addr.host_address, resolved_ip))
    {
        throw socket_error(ERESOLVE,"unable to resolve '" + addr.host_address + "' in connect_to()");
    }

    return connect_to_ip_and_port(resolved_ip, addr.port);
}
struct listen_on_port
{
    /*!
        Simple value type carrying a TCP port number.  The member is private;
        only class bridge (a friend) reads it.
    !*/

    /*!
        requires
            - port_ != 0
    !*/
    listen_on_port (
        unsigned short port_
    ) :
        port(port_)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( port != 0,
            "\t listen_on_port()"
            << "\n\t Invalid inputs were given to this function"
            << "\n\t port: " << port
            << "\n\t this: " << this
            );
    }

private:
    friend class bridge;

    const unsigned short port;
};
// Tag type wrapping a reference to a pipe.  The wrapped reference is private
// and accessible only to class bridge (a friend).
template <typename pipe_type>
struct bridge_transmit_decoration
{
    bridge_transmit_decoration (
        pipe_type& p_
    ) :
        p(p_)
    {
    }

private:
    friend class bridge;

    pipe_type& p;
};

// Convenience factory: deduces pipe_type and returns the decoration for p.
template <typename pipe_type>
bridge_transmit_decoration<pipe_type> transmit (
    pipe_type& p
)
{
    bridge_transmit_decoration<pipe_type> decorated(p);
    return decorated;
}
// Tag type wrapping a reference to a pipe.  The wrapped reference is private
// and accessible only to class bridge (a friend).
template <typename pipe_type>
struct bridge_receive_decoration
{
    bridge_receive_decoration (
        pipe_type& p_
    ) :
        p(p_)
    {
    }

private:
    friend class bridge;

    pipe_type& p;
};

// Convenience factory: deduces pipe_type and returns the decoration for p.
template <typename pipe_type>
bridge_receive_decoration<pipe_type> receive (
    pipe_type& p
)
{
    bridge_receive_decoration<pipe_type> decorated(p);
    return decorated;
}
// ----------------------------------------------------------------------------------------
// Plain record with a connected flag and the foreign endpoint's port and IP
// string.
struct bridge_status
{
    // Starts out as "not connected" with port 0 and an empty IP string.
    bridge_status() :
        is_connected(false),
        foreign_port(0)
    {
    }

    bool is_connected;
    unsigned short foreign_port;
    std::string foreign_ip;
};
// bridge_status objects are not serializable: this overload exists only to
// reject the attempt.  It always throws serialization_error.
inline void serialize (
    const bridge_status& /*item*/,
    std::ostream& /*out*/
)
{
    throw serialization_error("It is illegal to serialize bridge_status objects.");
}
// bridge_status objects are not deserializable: this overload exists only to
// reject the attempt.  It always throws serialization_error.
inline void deserialize (
    bridge_status& /*item*/,
    std::istream& /*in*/
)
{
    // The message names the operation that was actually attempted.
    throw serialization_error("It is illegal to deserialize bridge_status objects.");
}
// ----------------------------------------------------------------------------------------
namespace impl_brns
{
// Type-erased interface to a bridge implementation: it exposes only the
// status query, so callers need not know the pipe types.
class impl_bridge_base
{
public:
    virtual ~impl_bridge_base()
    {
    }

    // Returns the implementation's current bridge_status by value.
    virtual bridge_status get_bridge_status () const = 0;
};
template <
typename transmit_pipe_type,
typename receive_pipe_type
>
class impl_bridge : public impl_bridge_base, private noncopyable, private multithreaded_object
{
/*!
CONVENTION
- if (list) then
- this object is supposed to be listening on the list object for incoming
connections when not connected.
- else
- this object is supposed to be attempting to connect to ip:port when
not connected.
- get_bridge_status() == current_bs
!*/
public:
/*!
    Listening-mode constructor: creates a listener on listen_port, registers
    the transmit, receive and connect threads, and starts them.  Either pipe
    pointer may be null (the rest of the class checks before using them).

    throws
        - socket_error if create_listener() returns PORTINUSE or OTHER_ERROR
!*/
impl_bridge (
    unsigned short listen_port,
    transmit_pipe_type* transmit_pipe_,
    receive_pipe_type* receive_pipe_
) :
    s(m),
    receive_thread_active(false),
    transmit_thread_active(false),
    port(0),
    transmit_pipe(transmit_pipe_),
    receive_pipe(receive_pipe_),
    dlog("dlib.bridge"),
    keepalive_code(0),
    message_code(1)
{
    const int listener_status = create_listener(list, listen_port);

    if (listener_status == PORTINUSE)
    {
        std::ostringstream message;
        message << "Error, the port " << listen_port << " is already in use.";
        throw socket_error(EPORT_IN_USE, message.str());
    }

    if (listener_status == OTHER_ERROR)
    {
        throw socket_error("Unable to create listening socket for an unknown reason.");
    }

    // Threads are only registered and started once the listener exists.
    register_thread(*this, &impl_bridge::transmit_thread);
    register_thread(*this, &impl_bridge::receive_thread);
    register_thread(*this, &impl_bridge::connect_thread);

    start();
}
/*!
    Connecting-mode constructor: remembers the ip:port to connect to,
    registers the transmit, receive and connect threads, and starts them.  No
    listener is created in this mode.  Either pipe pointer may be null (the
    rest of the class checks before using them).
!*/
impl_bridge (
    // Taken by const reference; the ip member keeps its own copy.
    const std::string& ip_,
    unsigned short port_,
    transmit_pipe_type* transmit_pipe_,
    receive_pipe_type* receive_pipe_
) :
    s(m),
    receive_thread_active(false),
    transmit_thread_active(false),
    port(port_),
    ip(ip_),
    transmit_pipe(transmit_pipe_),
    receive_pipe(receive_pipe_),
    dlog("dlib.bridge"),
    keepalive_code(0),
    message_code(1)
{
    register_thread(*this, &impl_bridge::transmit_thread);
    register_thread(*this, &impl_bridge::receive_thread);
    register_thread(*this, &impl_bridge::connect_thread);

    start();
}
// Shuts the bridge down.  The sequence is: request thread termination,
// unblock anything waiting on the pipes, wake any thread waiting on s and
// shut down the connection if one exists, wait for the threads, and finally
// put the pipes' enabled state back the way it was found.
~impl_bridge()
{
// tell the threads to terminate
stop();
// save current pipe enabled status so we can restore it to however
// it was before this destructor ran.
bool transmit_enabled = true;
bool receive_enabled = true;
// make any calls blocked on a pipe return immediately.
// (Either pipe pointer may be null, hence the checks.)
if (transmit_pipe)
{
transmit_enabled = transmit_pipe->is_dequeue_enabled();
transmit_pipe->disable_dequeue();
}
if (receive_pipe)
{
receive_enabled = receive_pipe->is_enqueue_enabled();
receive_pipe->disable_enqueue();
}
{
// The broadcast and the connection shutdown both happen while m is locked.
auto_mutex lock(m);
s.broadcast();
// Shutdown the connection if we have one. This will cause
// all blocked I/O calls to return an error.
if (con)
con->shutdown();
}
// wait for all the threads to terminate.
wait();
// Re-enable only what was enabled when the destructor started.
if (transmit_pipe && transmit_enabled)
transmit_pipe->enable_dequeue();
if (receive_pipe && receive_enabled)
receive_pipe->enable_enqueue();
}
// Returns a copy of current_bs; the copy is made while current_bs_mutex is
// locked.
bridge_status get_bridge_status (
) const
{
    auto_mutex lock(current_bs_mutex);
    bridge_status snapshot(current_bs);
    return snapshot;
}
private:
template <typename pipe_type>
std::enable_if_t<std::is_convertible<bridge_status, typename pipe_type::type>::value> enqueue_bridge_status (
pipe_type* p,
const bridge_status& status
)
{
if (p)
{
typename pipe_type::type temp(status);
p->enqueue(temp);
}
}
// Overload selected when a bridge_status is NOT convertible to the pipe's
// element type: there is nothing that could be enqueued, so this is a no-op.
template <typename pipe_type>
std::enable_if_t<!std::is_convertible<bridge_status, typename pipe_type::type>::value> enqueue_bridge_status (
    pipe_type* /*p*/,
    const bridge_status& /*status*/
)
{
    // intentionally empty
}
// Thread body that owns connection establishment.  Each loop iteration
// obtains a connection (accepting on list when it is set, otherwise
// connecting to ip:port), publishes a "connected" status, activates the
// transmit and receive threads, waits until both have finished with the
// connection, and then publishes a "disconnected" status.
void connect_thread (
)
{
while (!should_stop())
{
// m is locked at the top of every iteration; the waits on s below happen
// with this lock in scope.
auto_mutex lock(m);
int status = OTHER_ERROR;
if (list)
{
// Listening mode: accept with a 1000 timeout argument and retry on
// TIMEOUT, so should_stop() gets re-checked between attempts.
do
{
status = list->accept(con, 1000);
} while (status == TIMEOUT && !should_stop());
}
else
{
// Connecting mode.
status = create_connection(con, port, ip);
}
if (should_stop())
break;
if (status != 0)
{
// The last connection attempt failed. So pause for a little bit before making another attempt.
s.wait_or_timeout(2000);
continue;
}
dlog << LINFO << "Established new connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << ".";
// temp_bs is a private copy of the status so that the enqueue below runs
// after current_bs_mutex has been released.
bridge_status temp_bs;
{ auto_mutex lock(current_bs_mutex);
current_bs.is_connected = true;
current_bs.foreign_port = con->get_foreign_port();
current_bs.foreign_ip = con->get_foreign_ip();
temp_bs = current_bs;
}
enqueue_bridge_status(receive_pipe, temp_bs);
// Mark both worker threads active, then broadcast on s.
receive_thread_active = true;
transmit_thread_active = true;
s.broadcast();
// Wait for the transmit and receive threads to end before we continue.
// This way we don't invalidate the con pointer while it is in use.
while (receive_thread_active || transmit_thread_active)
s.wait();
dlog << LINFO << "Closed connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << ".";
// Publish the disconnected status; the foreign port and IP are still read
// from con at this point.
{ auto_mutex lock(current_bs_mutex);
current_bs.is_connected = false;
current_bs.foreign_port = con->get_foreign_port();
current_bs.foreign_ip = con->get_foreign_ip();
temp_bs = current_bs;
}
enqueue_bridge_status(receive_pipe, temp_bs);
}
}
// Worker thread (registered in the constructor) that reads from the current
// connection.  It sleeps until connect_thread raises receive_thread_active, reads
// until the stream ends or an exception is thrown, then shuts the connection down,
// clears its flag and goes back to sleep.  It exits once should_stop() is true.
void receive_thread (
)
{
while (true)
{
// wait until we have a connection
{ auto_mutex lock(m);
while (!receive_thread_active && !should_stop())
{
s.wait();
}
if (should_stop())
break;
}
try
{
if (receive_pipe)
{
sockstreambuf buf(con);
std::istream in(&buf);
typename receive_pipe_type::type item;
// This isn't necessary but doing it avoids a warning about
// item being uninitialized sometimes.
assign_zero_if_built_in_scalar_type(item);
// Every record on the wire starts with a one byte code.  Keep going until
// peek() reports end of stream.
while (in.peek() != EOF)
{
unsigned char code;
in.read((char*)&code, sizeof(code));
// Only message_code is followed by a serialized item.  Any other byte
// (e.g. the keepalive byte written by transmit_thread) is skipped.
if (code == message_code)
{
deserialize(item, in);
receive_pipe->enqueue(item);
}
}
}
else
{
// Since we don't have a receive pipe to put messages into we will
// just read the bytes from the connection and ignore them.
char buf[1000];
while (con->read(buf, sizeof(buf)) > 0) ;
}
}
// Exceptions are logged rather than propagated; in every case control falls
// through to the shutdown code below.
catch (std::bad_alloc& )
{
dlog << LERROR << "std::bad_alloc thrown while deserializing message from "
<< con->get_foreign_ip() << ":" << con->get_foreign_port();
}
catch (dlib::serialization_error& e)
{
dlog << LERROR << "dlib::serialization_error thrown while deserializing message from "
<< con->get_foreign_ip() << ":" << con->get_foreign_port()
<< ".\nThe exception error message is: \n" << e.what();
}
catch (std::exception& e)
{
dlog << LERROR << "std::exception thrown while deserializing message from "
<< con->get_foreign_ip() << ":" << con->get_foreign_port()
<< ".\nThe exception error message is: \n" << e.what();
}
// We are done with this connection.  Shut it down, then tell connect_thread
// (which waits on s for both flags to clear) that this thread has let go of con.
con->shutdown();
auto_mutex lock(m);
receive_thread_active = false;
s.broadcast();
}
// Reached via the break above when stopping: clear the flag one last time so
// that connect_thread is not left waiting on it.
auto_mutex lock(m);
receive_thread_active = false;
s.broadcast();
}
// Worker thread (registered in the constructor) that writes to the current
// connection.  It sleeps until connect_thread raises transmit_thread_active, then
// forwards items from transmit_pipe (each prefixed with message_code) and writes a
// keepalive byte whenever there is nothing to send.  When the stream fails, an
// exception is thrown, or a stop is requested, it shuts the connection down,
// clears its flag and goes back to sleep.  It exits once should_stop() is true.
void transmit_thread (
)
{
while (true)
{
// wait until we have a connection
{ auto_mutex lock(m);
while (!transmit_thread_active && !should_stop())
{
s.wait();
}
if (should_stop())
break;
}
try
{
sockstreambuf buf(con);
std::ostream out(&buf);
typename transmit_pipe_type::type item;
// This isn't necessary but doing it avoids a warning about
// item being uninitialized sometimes.
assign_zero_if_built_in_scalar_type(item);
// Loop for as long as the output stream is in a good state.
while (out)
{
bool dequeue_timed_out = false;
if (transmit_pipe )
{
if (transmit_pipe->dequeue_or_timeout(item,1000))
{
out.write((char*)&message_code, sizeof(message_code));
serialize(item, out);
// Only flush once the pipe has been drained so that back-to-back
// items are not flushed one at a time.
if (transmit_pipe->size() == 0)
out.flush();
continue;
}
// Nothing was dequeued.  If the pipe is still enabled for dequeuing then the
// call above really did wait out its timeout, so the pause below is skipped.
dequeue_timed_out = (transmit_pipe->is_enabled() && transmit_pipe->is_dequeue_enabled());
}
// Pause for about a second. Note that we use a wait_or_timeout() call rather
// than sleep() here because we want to wake up immediately if this object is
// being destructed rather than hang for a second.
if (!dequeue_timed_out)
{
auto_mutex lock(m);
// Leaves the while (out) loop; the shutdown code after the try block still runs.
if (should_stop())
break;
s.wait_or_timeout(1000);
}
// Just send the keepalive byte periodically so we can
// tell if the connection is alive.
out.write((char*)&keepalive_code, sizeof(keepalive_code));
out.flush();
}
}
// Exceptions are logged rather than propagated; in every case control falls
// through to the shutdown code below.
catch (std::bad_alloc& )
{
dlog << LERROR << "std::bad_alloc thrown while serializing message to "
<< con->get_foreign_ip() << ":" << con->get_foreign_port();
}
catch (dlib::serialization_error& e)
{
dlog << LERROR << "dlib::serialization_error thrown while serializing message to "
<< con->get_foreign_ip() << ":" << con->get_foreign_port()
<< ".\nThe exception error message is: \n" << e.what();
}
catch (std::exception& e)
{
dlog << LERROR << "std::exception thrown while serializing message to "
<< con->get_foreign_ip() << ":" << con->get_foreign_port()
<< ".\nThe exception error message is: \n" << e.what();
}
// We are done with this connection.  Shut it down, then tell connect_thread
// (which waits on s for both flags to clear) that this thread has let go of con.
con->shutdown();
auto_mutex lock(m);
transmit_thread_active = false;
s.broadcast();
}
// Reached via the break above when stopping: clear the flag one last time so
// that connect_thread is not left waiting on it.
auto_mutex lock(m);
transmit_thread_active = false;
s.broadcast();
}
// Held while the *_thread_active flags are read or written, by connect_thread
// while it establishes a connection, and by the destructor while it shuts con down.
mutex m;
// Broadcast whenever the flags below change and when the object is being destroyed.
signaler s;
// Raised by connect_thread once a connection is ready; each is cleared by the
// corresponding worker thread when it has finished using con.
bool receive_thread_active;
bool transmit_thread_active;
// The current connection.  It is (re)assigned only by connect_thread.
std::unique_ptr<connection> con;
// When non-null connect_thread accepts incoming connections on it; otherwise
// connect_thread makes outgoing connections to ip:port.
std::unique_ptr<listener> list;
const unsigned short port;
const std::string ip;
// Either pipe pointer may be null; the code checks before using them.
transmit_pipe_type* const transmit_pipe;
receive_pipe_type* const receive_pipe;
logger dlog;
// One byte codes written in front of each record on the wire (initialized to 0
// and 1 respectively by the constructor).
const unsigned char keepalive_code;
const unsigned char message_code;
// current_bs_mutex guards current_bs, the status returned by get_bridge_status().
mutex current_bs_mutex;
bridge_status current_bs;
};
}
// ----------------------------------------------------------------------------------------
class bridge : noncopyable
{
    /*!
        Front end of the bridge tool.  All of the work is delegated to an
        impl_brns::impl_bridge object owned through pimpl.  A null pimpl means
        the bridge is inactive.

        Every reconfigure() overload first destroys the current implementation
        and only then constructs its replacement.
    !*/
public:

    // Creates an inactive bridge.
    bridge () {}

    // Same as default construction followed by
    // reconfigure(network_parameters, pipe1, pipe2).
    template < typename T, typename U, typename V >
    bridge (
        T network_parameters,
        U pipe1,
        V pipe2
    )
    {
        reconfigure(network_parameters, pipe1, pipe2);
    }

    // Same as default construction followed by
    // reconfigure(network_parameters, pipe).
    template < typename T, typename U>
    bridge (
        T network_parameters,
        U pipe
    )
    {
        reconfigure(network_parameters, pipe);
    }

    // Returns the bridge to its inactive state.
    void clear (
    )
    {
        pimpl.reset();
    }

    // ---- listening bridges ----

    // Listen on a port; both a transmit and a receive pipe.
    template < typename T, typename R >
    void reconfigure (
        listen_on_port network_parameters,
        bridge_transmit_decoration<T> transmit_pipe,
        bridge_receive_decoration<R> receive_pipe
    )
    {
        clear();
        pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.port, &transmit_pipe.p, &receive_pipe.p));
    }

    // Same as above with the pipe arguments given in the other order.
    template < typename T, typename R >
    void reconfigure (
        listen_on_port network_parameters,
        bridge_receive_decoration<R> receive_pipe,
        bridge_transmit_decoration<T> transmit_pipe
    )
    {
        clear();
        pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.port, &transmit_pipe.p, &receive_pipe.p));
    }

    // Listen on a port; transmit pipe only.
    template < typename T >
    void reconfigure (
        listen_on_port network_parameters,
        bridge_transmit_decoration<T> transmit_pipe
    )
    {
        clear();
        pimpl.reset(new impl_brns::impl_bridge<T,T>(network_parameters.port, &transmit_pipe.p, 0));
    }

    // Listen on a port; receive pipe only.
    template < typename R >
    void reconfigure (
        listen_on_port network_parameters,
        bridge_receive_decoration<R> receive_pipe
    )
    {
        clear();
        pimpl.reset(new impl_brns::impl_bridge<R,R>(network_parameters.port, 0, &receive_pipe.p));
    }

    // ---- connecting bridges ----

    // Connect to ip:port; both a transmit and a receive pipe.
    template < typename T, typename R >
    void reconfigure (
        connect_to_ip_and_port network_parameters,
        bridge_transmit_decoration<T> transmit_pipe,
        bridge_receive_decoration<R> receive_pipe
    )
    {
        clear();
        pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p));
    }

    // Same as above with the pipe arguments given in the other order.
    template < typename T, typename R >
    void reconfigure (
        connect_to_ip_and_port network_parameters,
        bridge_receive_decoration<R> receive_pipe,
        bridge_transmit_decoration<T> transmit_pipe
    )
    {
        clear();
        pimpl.reset(new impl_brns::impl_bridge<T,R>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p));
    }

    // Connect to ip:port; receive pipe only.
    template < typename R >
    void reconfigure (
        connect_to_ip_and_port network_parameters,
        bridge_receive_decoration<R> receive_pipe
    )
    {
        clear();
        pimpl.reset(new impl_brns::impl_bridge<R,R>(network_parameters.ip, network_parameters.port, 0, &receive_pipe.p));
    }

    // Connect to ip:port; transmit pipe only.
    template < typename T >
    void reconfigure (
        connect_to_ip_and_port network_parameters,
        bridge_transmit_decoration<T> transmit_pipe
    )
    {
        clear();
        pimpl.reset(new impl_brns::impl_bridge<T,T>(network_parameters.ip, network_parameters.port, &transmit_pipe.p, 0));
    }

    // Returns the implementation's current status, or a default constructed
    // bridge_status when the bridge is inactive.
    bridge_status get_bridge_status (
    ) const
    {
        if (!pimpl)
            return bridge_status();
        return pimpl->get_bridge_status();
    }

private:
    std::unique_ptr<impl_brns::impl_bridge_base> pimpl;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BRIDGe_Hh_
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_BRIDGe_ABSTRACT_
#ifdef DLIB_BRIDGe_ABSTRACT_
#include <string>
#include "../pipe/pipe_kernel_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
struct connect_to_ip_and_port
{
connect_to_ip_and_port (
const std::string& ip,
unsigned short port
);
/*!
requires
- is_ip_address(ip) == true
- port != 0
ensures
- this object will represent a request to make a TCP connection
to the given IP address and port number.
!*/
};
connect_to_ip_and_port connect_to (
const network_address& addr
);
/*!
requires
- addr.port != 0
ensures
- converts the given network_address object into a connect_to_ip_and_port
object.
!*/
struct listen_on_port
{
listen_on_port(
unsigned short port
);
/*!
requires
- port != 0
ensures
- this object will represent a request to listen on the given
port number for incoming TCP connections.
!*/
};
template <
typename pipe_type
>
bridge_transmit_decoration<pipe_type> transmit (
pipe_type& p
);
/*!
requires
- pipe_type is some kind of dlib::pipe object
- the objects in the pipe must be serializable
ensures
- Adds a type decoration to the given pipe, marking it as a transmit pipe, and
then returns it.
!*/
template <
typename pipe_type
>
bridge_receive_decoration<pipe_type> receive (
pipe_type& p
);
/*!
requires
- pipe_type is some kind of dlib::pipe object
- the objects in the pipe must be serializable
ensures
- Adds a type decoration to the given pipe, marking it as a receive pipe, and
then returns it.
!*/
// ----------------------------------------------------------------------------------------
struct bridge_status
{
/*!
WHAT THIS OBJECT REPRESENTS
This simple struct represents the state of a bridge object. A
bridge is either connected or not. If it is connected then it
is connected to a foreign host with an IP address and port number
as indicated by this object.
!*/
bridge_status(
);
/*!
ensures
- #is_connected == false
- #foreign_port == 0
- #foreign_ip == ""
!*/
bool is_connected;
unsigned short foreign_port;
std::string foreign_ip;
};
// ----------------------------------------------------------------------------------------
class bridge : noncopyable
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is a tool for bridging a dlib::pipe object between
two network connected applications.
Note also that this object contains a dlib::logger object
which will log various events taking place inside a bridge.
If you want to see these log messages then enable the logger
named "dlib.bridge".
BRIDGE PROTOCOL DETAILS
The bridge object creates a single TCP connection between
two applications. Whenever it sends an object from a pipe
over a TCP connection it sends a byte with the value 1 followed
immediately by the serialized copy of the object from the pipe.
The serialization is performed by calling the global serialize()
function.
Additionally, a bridge object will periodically send bytes with
a value of 0 to ensure the TCP connection remains alive. These
are just read and ignored.
!*/
public:
bridge (
);
/*!
ensures
- this object is properly initialized
- #get_bridge_status().is_connected == false
!*/
template <typename T, typename U, typename V>
bridge (
T network_parameters,
U pipe1,
V pipe2
);
/*!
requires
- T is of type connect_to_ip_and_port or listen_on_port
- U and V are of type bridge_transmit_decoration or bridge_receive_decoration,
however, U and V must be of different types (i.e. one is a receive type and
another a transmit type).
ensures
- this object is properly initialized
- performs: reconfigure(network_parameters, pipe1, pipe2)
(i.e. using this constructor is identical to using the default constructor
and then calling reconfigure())
!*/
template <typename T, typename U>
bridge (
T network_parameters,
U pipe
);
/*!
requires
- T is of type connect_to_ip_and_port or listen_on_port
- U is of type bridge_transmit_decoration or bridge_receive_decoration.
ensures
- this object is properly initialized
- performs: reconfigure(network_parameters, pipe)
(i.e. using this constructor is identical to using the default constructor
and then calling reconfigure())
!*/
~bridge (
);
/*!
ensures
- blocks until all resources associated with this object have been destroyed.
!*/
void clear (
);
/*!
ensures
- returns this object to its default constructed state. That is, it will
be inactive, neither maintaining a connection nor attempting to acquire one.
- Any active connections or listening sockets will be closed.
!*/
bridge_status get_bridge_status (
) const;
/*!
ensures
- returns the current status of this bridge object. In particular, returns
an object BS such that:
- BS.is_connected == true if and only if the bridge has an active TCP
connection to another computer.
- if (BS.is_connected) then
- BS.foreign_ip == the IP address of the remote host we are connected to.
- BS.foreign_port == the port number on the remote host we are connected to.
- else if (the bridge has previously been connected to a remote host but hasn't been
reconfigured or cleared since) then
- BS.foreign_ip == the IP address of the remote host we were connected to.
- BS.foreign_port == the port number on the remote host we were connected to.
- else
- BS.foreign_ip == ""
- BS.foreign_port == 0
!*/
template < typename T, typename R >
void reconfigure (
listen_on_port network_parameters,
bridge_transmit_decoration<T> transmit_pipe,
bridge_receive_decoration<R> receive_pipe
);
/*!
ensures
- This object will begin listening on the port specified by network_parameters
for incoming TCP connections. Any previous bridge state is cleared out.
- Once a connection is established we will:
- Stop accepting new connections.
- Begin dequeuing objects from the transmit pipe and serializing them over
the TCP connection.
- Begin deserializing objects from the TCP connection and enqueueing them
onto the receive pipe.
- if (the current TCP connection is lost) then
- This object goes back to listening for a new connection.
- if (the receive pipe can contain bridge_status objects) then
- Whenever the bridge's status changes the updated bridge_status will be
enqueued onto the receive pipe unless the change was a TCP disconnect
resulting from a user calling reconfigure(), clear(), or destructing this
bridge. The status contents are defined by get_bridge_status().
throws
- socket_error
This exception is thrown if we are unable to open the listening socket.
!*/
template < typename T, typename R >
void reconfigure (
listen_on_port network_parameters,
bridge_receive_decoration<R> receive_pipe,
bridge_transmit_decoration<T> transmit_pipe
);
/*!
ensures
- performs reconfigure(network_parameters, transmit_pipe, receive_pipe)
!*/
template < typename T >
void reconfigure (
listen_on_port network_parameters,
bridge_transmit_decoration<T> transmit_pipe
);
/*!
ensures
- This function is identical to the above two reconfigure() functions
except that there is no receive pipe.
!*/
template < typename R >
void reconfigure (
listen_on_port network_parameters,
bridge_receive_decoration<R> receive_pipe
);
/*!
ensures
- This function is identical to the above three reconfigure() functions
except that there is no transmit pipe.
!*/
template <typename T, typename R>
void reconfigure (
connect_to_ip_and_port network_parameters,
bridge_transmit_decoration<T> transmit_pipe,
bridge_receive_decoration<R> receive_pipe
);
/*!
ensures
- This object will begin making TCP connection attempts to the IP address and port
specified by network_parameters. Any previous bridge state is cleared out.
- Once a connection is established we will:
- Stop attempting new connections.
- Begin dequeuing objects from the transmit pipe and serializing them over
the TCP connection.
- Begin deserializing objects from the TCP connection and enqueueing them
onto the receive pipe.
- if (the current TCP connection is lost) then
- This object goes back to attempting to make a TCP connection with the
IP address and port specified by network_parameters.
- if (the receive pipe can contain bridge_status objects) then
- Whenever the bridge's status changes the updated bridge_status will be
enqueued onto the receive pipe unless the change was a TCP disconnect
resulting from a user calling reconfigure(), clear(), or destructing this
bridge. The status contents are defined by get_bridge_status().
!*/
template <typename T, typename R>
void reconfigure (
connect_to_ip_and_port network_parameters,
bridge_receive_decoration<R> receive_pipe,
bridge_transmit_decoration<T> transmit_pipe
);
/*!
ensures
- performs reconfigure(network_parameters, transmit_pipe, receive_pipe)
!*/
template <typename T>
void reconfigure (
connect_to_ip_and_port network_parameters,
bridge_transmit_decoration<T> transmit_pipe
);
/*!
ensures
- This function is identical to the above two reconfigure() functions
except that there is no receive pipe.
!*/
template <typename R>
void reconfigure (
connect_to_ip_and_port network_parameters,
bridge_receive_decoration<R> receive_pipe
);
/*!
ensures
- This function is identical to the above three reconfigure() functions
except that there is no transmit pipe.
!*/
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BRIDGe_ABSTRACT_
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BSPh_
#define DLIB_BSPh_
#include "bsp/bsp.h"
#endif // DLIB_BSPh_
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BSP_CPph_
#define DLIB_BSP_CPph_
#include "bsp.h"
#include <memory>
#include <stack>
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace dlib
{
namespace impl1
{
void connect_all (
    map_id_to_con& cons,
    const std::vector<network_address>& hosts,
    unsigned long node_id
)
{
    // Start from an empty connection table.
    cons.clear();

    // hosts[0] becomes node 1, hosts[1] becomes node 2, and so on.
    unsigned long num_connected = 0;
    for (const network_address& host : hosts)
    {
        std::unique_ptr<bsp_con> link(new bsp_con(host));

        // The first thing the other end learns is who we are.
        dlib::serialize(node_id, link->stream);

        // add() is handed a fresh scratch variable on every iteration rather
        // than the running counter itself.
        unsigned long assigned_id = ++num_connected;
        cons.add(assigned_id, link);
    }
}
void connect_all_hostinfo (
    map_id_to_con& cons,
    const std::vector<hostinfo>& hosts,
    unsigned long node_id,
    std::string& error_string
)
{
    // Start from an empty connection table.
    cons.clear();

    for (const hostinfo& target : hosts)
    {
        try
        {
            std::unique_ptr<bsp_con> link(new bsp_con(target.addr));

            // Tell the other end our node_id and make sure it goes out now.
            dlib::serialize(node_id, link->stream);
            link->stream.flush();

            // Store the connection under the id carried in the hostinfo.
            unsigned long peer_id = target.node_id;
            cons.add(peer_id, link);
        }
        catch (std::exception&)
        {
            // Report the first host that could not be set up and give up on
            // the remaining ones.  Connections made so far stay in cons.
            std::ostringstream sout;
            sout << "Could not connect to " << target.addr;
            error_string = sout.str();
            return;
        }
    }
}
void send_out_connection_orders (
    map_id_to_con& cons,
    const std::vector<network_address>& hosts
)
{
    // Step 1: every node is told the id it has been stored under.
    for (cons.reset(); cons.move_next(); )
        dlib::serialize(cons.element().key(), cons.element().value()->stream);

    // Step 2: every node is told which nodes to connect to.  Node k receives
    // the list of nodes 1..k-1, i.e. everything accumulated before it.
    std::vector<hostinfo> targets;
    for (unsigned long idx = 0; idx < hosts.size(); ++idx)
    {
        const hostinfo info(hosts[idx], idx+1);
        std::iostream& out = cons[info.node_id]->stream;

        // The list is sent before this node is appended to it.
        dlib::serialize(targets, out);
        targets.push_back(info);

        // Let the other host know how many incoming connections to expect:
        // one from each host that has not been processed yet.
        const unsigned long num = hosts.size()-targets.size();
        dlib::serialize(num, out);
        out.flush();
    }
}
// ------------------------------------------------------------------------------------
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace impl2
{
// These control bytes are sent before each message between nodes. Note that many
// of these are only sent between the control node (node 0) and the other nodes.
// This is because the controller node is responsible for handling the
// synchronization that needs to happen when all nodes block on calls to
// receive_data() at the same time.
// NOTE(review): the value 4 is not assigned to any control byte in this list.

// denotes a normal content message.
const static char MESSAGE_HEADER = 0;
// sent to the controller node when someone receives a message via receive_data().
const static char GOT_MESSAGE = 1;
// sent to the controller node when someone sends a message via send().
const static char SENT_MESSAGE = 2;
// sent to the controller node when someone enters a call to receive_data()
const static char IN_WAITING_STATE = 3;
// broadcast when a node terminates itself.
const static char NODE_TERMINATE = 5;
// broadcast by the controller node when it determines that all nodes are blocked
// on calls to receive_data() and there aren't any messages in flight. This is also
// what makes us go to the next epoch.
const static char SEE_ALL_IN_WAITING_STATE = 6;
// This isn't ever transmitted between nodes. It is used internally to indicate
// that an error occurred.
const static char READ_ERROR = 7;
// ------------------------------------------------------------------------------------
void read_thread (
    impl1::bsp_con* con,
    unsigned long node_id,
    unsigned long sender_id,
    impl1::thread_safe_message_queue& msg_buffer
)
{
    // Reads messages arriving on one connection and pushes them into
    // msg_buffer.  The loop ends after a NODE_TERMINATE message has been
    // forwarded.  Any exception ends it too, in which case a READ_ERROR
    // message describing the failure is pushed instead.
    try
    {
        for (;;)
        {
            impl1::msg_data incoming;
            deserialize(incoming.msg_type, con->stream);
            incoming.sender_id = sender_id;

            // Only content messages carry an epoch and a payload.
            if (incoming.msg_type == MESSAGE_HEADER)
            {
                incoming.data.reset(new std::vector<char>);
                deserialize(incoming.epoch, con->stream);
                deserialize(*incoming.data, con->stream);
            }

            msg_buffer.push_and_consume(incoming);

            // The type is inspected after the push, as in the original code.
            if (incoming.msg_type == NODE_TERMINATE)
                break;
        }
    }
    catch (std::exception& e)
    {
        impl1::msg_data failure;
        failure.data.reset(new std::vector<char>);
        vectorstream sout(*failure.data);
        sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
        sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
        sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
        sout << " Receiving processing node id: " << node_id << std::endl;
        sout << " Error message in the exception: " << e.what() << std::endl;
        failure.sender_id = sender_id;
        failure.msg_type = READ_ERROR;
        msg_buffer.push_and_consume(failure);
    }
    catch (...)
    {
        // Same report as above, minus the exception text we cannot get at.
        impl1::msg_data failure;
        failure.data.reset(new std::vector<char>);
        vectorstream sout(*failure.data);
        sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
        sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
        sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
        sout << " Receiving processing node id: " << node_id << std::endl;
        failure.sender_id = sender_id;
        failure.msg_type = READ_ERROR;
        msg_buffer.push_and_consume(failure);
    }
}
// ------------------------------------------------------------------------------------
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// IMPLEMENTATION OF bsp_context OBJECT MEMBERS
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
void bsp_context::
close_all_connections_gracefully(
)
{
    // Every node other than the controller announces up front that it is
    // intentionally dropping its connections.
    if (node_id() != 0)
    {
        for (_cons.reset(); _cons.move_next(); )
        {
            serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
            _cons.element().value()->stream.flush();
        }
    }

    // Now process incoming messages until every other node has terminated.
    impl1::msg_data msg;
    while (num_terminated_nodes < _cons.size() )
    {
        // The controller keeps running the "everyone is waiting" protocol for
        // the nodes that are still alive.
        const bool all_idle = node_id() == 0 &&
                              num_waiting_nodes + num_terminated_nodes == _cons.size() &&
                              outstanding_messages == 0;
        if (all_idle)
        {
            num_waiting_nodes = 0;
            broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
            ++current_epoch;
        }

        if (!msg_buffer.pop(msg))
            throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");

        switch (msg.msg_type)
        {
            case impl2::NODE_TERMINATE:
                ++num_terminated_nodes;
                _cons[msg.sender_id]->terminated = true;
                break;

            case impl2::READ_ERROR:
                throw dlib::socket_error(msg.data_to_string());

            case impl2::MESSAGE_HEADER:
                throw dlib::socket_error("A BSP node received a message after it has terminated.");

            case impl2::GOT_MESSAGE:
                --num_waiting_nodes;
                --outstanding_messages;
                break;

            case impl2::SENT_MESSAGE:
                ++outstanding_messages;
                break;

            case impl2::IN_WAITING_STATE:
                ++num_waiting_nodes;
                break;

            default:
                // Any other message type is ignored here.
                break;
        }
    }

    // The controller says goodbye last, once everyone else has terminated.
    if (node_id() == 0)
    {
        for (_cons.reset(); _cons.move_next(); )
        {
            // tell the other end that we are intentionally dropping the connection
            serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
            _cons.element().value()->stream.flush();
        }

        // Messages that were sent but never received indicate a user error.
        if (outstanding_messages != 0)
        {
            std::ostringstream sout;
            sout << "A BSP job was allowed to terminate before all sent messages have been received.\n";
            sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n";
            sout << "have a corresponding call to receive().";
            throw dlib::socket_error(sout.str());
        }
    }
}
// ----------------------------------------------------------------------------------------
bsp_context::
~bsp_context()
{
    // Shut down every connection.
    for (_cons.reset(); _cons.move_next(); )
        _cons.element().value()->con->shutdown();

    msg_buffer.disable();

    // this will wait for all the threads to terminate
    threads.clear();
}
// ----------------------------------------------------------------------------------------
bsp_context::
bsp_context(
    unsigned long node_id_,
    impl1::map_id_to_con& cons_
) :
    outstanding_messages(0),
    num_waiting_nodes(0),
    num_terminated_nodes(0),
    current_epoch(1),
    _cons(cons_),
    _node_id(node_id_)
{
    // One reader thread per connection.  Each runs impl2::read_thread with the
    // connection, our own id, the id of the node at the other end, and the
    // shared message buffer.
    for (_cons.reset(); _cons.move_next(); )
    {
        std::unique_ptr<thread_function> reader(new thread_function(
                &impl2::read_thread,
                _cons.element().value().get(),
                _node_id,
                _cons.element().key(),
                ref(msg_buffer)));
        threads.push_back(reader);
    }
}
// ----------------------------------------------------------------------------------------
bool bsp_context::
receive_data (
std::shared_ptr<std::vector<char> >& item,
unsigned long& sending_node_id
)
{
notify_control_node(impl2::IN_WAITING_STATE);
while (true)
{
// If there aren't any nodes left to give us messages then return right now.
// We need to check the msg_buffer size to make sure there aren't any
// unprocessed message there. Recall that this can happen because status
// messages always jump to the front of the message buffer. So we might have
// learned about the node terminations before processing their messages for us.
if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0)
{
return false;
}
// if all running nodes are currently blocking forever on receive_data()
if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
{
num_waiting_nodes = 0;
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
// Note that the reason we have this epoch counter is so we can tell if a
// sent message is from before or after one of these "all nodes waiting"
// synchronization events. If we didn't have the epoch count we would have
// a race condition where one node gets the SEE_ALL_IN_WAITING_STATE
// message before others and then sends out a message to another node
// before that node got the SEE_ALL_IN_WAITING_STATE message. Then that
// node would think the normal message came before SEE_ALL_IN_WAITING_STATE
// which would be bad.
++current_epoch;
return false;
}
impl1::msg_data data;
if (!msg_buffer.pop(data, current_epoch))
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
switch(data.msg_type)
{
case impl2::MESSAGE_HEADER: {
item = data.data;
sending_node_id = data.sender_id;
notify_control_node(impl2::GOT_MESSAGE);
return true;
} break;
case impl2::IN_WAITING_STATE: {
++num_waiting_nodes;
} break;
case impl2::GOT_MESSAGE: {
--outstanding_messages;
--num_waiting_nodes;
} break;
case impl2::SENT_MESSAGE: {
++outstanding_messages;
} break;
case impl2::NODE_TERMINATE: {
++num_terminated_nodes;
_cons[data.sender_id]->terminated = true;
} break;
case impl2::SEE_ALL_IN_WAITING_STATE: {
++current_epoch;
return false;
} break;
case impl2::READ_ERROR: {
throw dlib::socket_error(data.data_to_string());
} break;
default: {
throw dlib::socket_error("Unknown message received by dlib::bsp_context");
} break;
} // end switch()
} // end while (true)
}
// ----------------------------------------------------------------------------------------
void bsp_context::
notify_control_node (
char val
)
{
if (node_id() == 0)
{
using namespace impl2;
switch(val)
{
case SENT_MESSAGE: {
++outstanding_messages;
} break;
case GOT_MESSAGE: {
--outstanding_messages;
} break;
case IN_WAITING_STATE: {
// nothing to do in this case
} break;
default:
DLIB_CASSERT(false,"This should never happen");
}
}
else
{
serialize(val, _cons[0]->stream);
_cons[0]->stream.flush();
}
}
// ----------------------------------------------------------------------------------------
void bsp_context::
broadcast_byte (
    char val
)
{
    for (unsigned long id = 0; id < number_of_nodes(); ++id)
    {
        // Send to everyone except ourselves and nodes that have terminated.
        // The self check comes first so _cons is never indexed with our own id.
        if (id != node_id() && !_cons[id]->terminated)
        {
            serialize(val, _cons[id]->stream);
            _cons[id]->stream.flush();
        }
    }
}
// ----------------------------------------------------------------------------------------
void bsp_context::
send_data(
const std::vector<char>& item,
unsigned long target_node_id
)
{
using namespace impl2;
if (_cons[target_node_id]->terminated)
throw socket_error("Attempt to send a message to a node that has terminated.");
serialize(MESSAGE_HEADER, _cons[target_node_id]->stream);
serialize(current_epoch, _cons[target_node_id]->stream);
serialize(item, _cons[target_node_id]->stream);
_cons[target_node_id]->stream.flush();
notify_control_node(SENT_MESSAGE);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BSP_CPph_
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BsP_Hh_
#define DLIB_BsP_Hh_
#include "bsp_abstract.h"
#include <memory>
#include <queue>
#include <vector>
#include "../sockets.h"
#include "../array.h"
#include "../sockstreambuf.h"
#include "../string.h"
#include "../serialize.h"
#include "../map.h"
#include "../ref.h"
#include "../vectorstream.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace impl1
{
// Port notification callback that does nothing with the value it is given.
inline void null_notify(unsigned short /*port*/)
{
}
// One connection to another node together with the stream used to talk over it.
// The members are declared in the order con, buf, stream, so they are also
// initialized in that order regardless of how the initializer lists are written.
struct bsp_con
{
// Opens a new connection to dest.  con is initialized first, so it already
// holds the connection when buf is constructed from it.
bsp_con(
const network_address& dest
) :
con(connect(dest)),
buf(con),
stream(&buf),
terminated(false)
{
con->disable_nagle();
}
// Adopts an existing connection.  con starts out empty here; buf is built from
// the caller's pointer and the swap below then moves ownership into con, which
// leaves conptr empty.
// NOTE(review): this relies on sockstreambuf not keeping a reference to the
// unique_ptr object itself -- confirm against sockstreambuf.
bsp_con(
std::unique_ptr<connection>& conptr
) :
buf(conptr),
stream(&buf),
terminated(false)
{
// make sure we own the connection
conptr.swap(con);
con->disable_nagle();
}
// The owned connection.
std::unique_ptr<connection> con;
// Stream buffer constructed from the connection.
sockstreambuf buf;
// Stream layered on buf.
std::iostream stream;
// Starts out false.
bool terminated;
};
typedef dlib::map<unsigned long, std::unique_ptr<bsp_con> >::kernel_1a_c map_id_to_con;
void connect_all (
map_id_to_con& cons,
const std::vector<network_address>& hosts,
unsigned long node_id
);
/*!
ensures
- creates connections to all the given hosts and stores them into cons
!*/
// Called by every bsp_connect() overload right after connect_all(), with the
// same cons and hosts arguments.
// NOTE(review): presumably this sends each host the data that
// listen_and_connect_all() reads from its first connection (its node id, the
// hosts it must connect to, and how many incoming connections to expect).  The
// definition is not in this header -- confirm there.
void send_out_connection_orders (
map_id_to_con& cons,
const std::vector<network_address>& hosts
);
// ------------------------------------------------------------------------------------
// Pairs the network address of a BSP node with that node's id.
struct hostinfo
{
    // Give node_id a defined value so that a default constructed hostinfo
    // (e.g. an element created while deserializing a std::vector<hostinfo>)
    // never holds an indeterminate id.
    hostinfo() : node_id(0) {}

    hostinfo (
        const network_address& addr_,
        unsigned long node_id_
    ) :
        addr(addr_),
        node_id(node_id_)
    {
    }

    network_address addr;
    unsigned long node_id;
};
inline void serialize (
const hostinfo& item,
std::ostream& out
)
{
dlib::serialize(item.addr, out);
dlib::serialize(item.node_id, out);
}
// Reads a hostinfo from in, in the order serialize() writes it: the address
// first, then the node id.
inline void deserialize (hostinfo& item, std::istream& in)
{
    network_address& address = item.addr;
    unsigned long& id = item.node_id;

    dlib::deserialize(address, in);
    dlib::deserialize(id, in);
}
// ------------------------------------------------------------------------------------
// Run on a helper thread by listen_and_connect_all(), which hands it an empty
// map, the list of target hosts it was sent, this node's id, and a string used
// to report failure.  listen_and_connect_all() later moves every connection
// found in cons into its own map and treats a non-empty error_string as a
// failure that it throws as a socket_error.
// NOTE(review): the definition is not in this header; presumably it connects to
// each entry of hosts and keys the connection by that entry's node_id --
// confirm there.
void connect_all_hostinfo (
map_id_to_con& cons,
const std::vector<hostinfo>& hosts,
unsigned long node_id,
std::string& error_string
);
// ------------------------------------------------------------------------------------
template <
    typename port_notify_function_type
    >
void listen_and_connect_all(
    unsigned long& node_id,
    map_id_to_con& cons,
    unsigned short port,
    port_notify_function_type port_notify_function
)
/*!
    ensures
        - Listens on port, reports the actual listening port through
          port_notify_function(), and blocks until one connection is accepted.
        - Reads, in this order, from that first connection: the remote node's
          id, #node_id, the list of hosts this node must connect to, and the
          number of further incoming connections to expect.
        - #cons == the connections to every peer: the first connection, the
          ones made to the listed hosts (by a helper thread running
          connect_all_hostinfo()), and the ones accepted afterwards.  Each is
          keyed by the peer's node id.
    throws
        - socket_error
            If the listener can't be created, an accept fails (the later
            accepts use a 10 second timeout), or connect_all_hostinfo()
            reported an error through its error string.
!*/
{
    cons.clear();

    std::unique_ptr<listener> list;
    const int status = create_listener(list, port);
    if (status == PORTINUSE)
    {
        throw socket_error("Unable to create listening port " + cast_to_string(port) +
                           ". The port is already in use");
    }
    else if (status != 0)
    {
        throw socket_error("Unable to create listening port " + cast_to_string(port) );
    }

    port_notify_function(list->get_listening_port());

    // The first connection delivers this node's configuration.
    std::unique_ptr<connection> con;
    if (list->accept(con))
    {
        throw socket_error("Error occurred while accepting new connection");
    }

    std::unique_ptr<bsp_con> temp(new bsp_con(con));

    unsigned long remote_node_id;
    dlib::deserialize(remote_node_id, temp->stream);
    dlib::deserialize(node_id, temp->stream);
    std::vector<hostinfo> targets;
    dlib::deserialize(targets, temp->stream);
    unsigned long num_incoming_connections;
    dlib::deserialize(num_incoming_connections, temp->stream);

    cons.add(remote_node_id,temp);

    // make a thread that will connect to all the targets
    map_id_to_con cons2;
    std::string error_string;
    thread_function thread(connect_all_hostinfo, dlib::ref(cons2), dlib::ref(targets), node_id, dlib::ref(error_string));
    // error_string and cons2 belong to the helper thread until thread.wait()
    // returns, so they are deliberately not touched before that point.

    // accept any incoming connections
    for (unsigned long i = 0; i < num_incoming_connections; ++i)
    {
        // If it takes more than 10 seconds for the other nodes to connect to us
        // then something has gone horribly wrong and it almost certainly will
        // never connect at all.  So just give up if that happens.
        const unsigned long timeout_milliseconds = 10000;
        if (list->accept(con, timeout_milliseconds))
        {
            throw socket_error("Error occurred while accepting new connection");
        }

        temp.reset(new bsp_con(con));
        dlib::deserialize(remote_node_id, temp->stream);
        cons.add(remote_node_id,temp);
    }

    // Wait for the helper thread, and only then look at what it produced.
    // Checking error_string any earlier would race with the thread and would
    // miss almost every failure it reports.
    thread.wait();
    if (error_string.size() != 0)
        throw socket_error(error_string);

    // put all the connections created by the thread into cons
    while (cons2.size() > 0)
    {
        unsigned long id;
        std::unique_ptr<bsp_con> outgoing;
        cons2.remove_any(id,outgoing);
        cons.add(id,outgoing);
    }
}
// ------------------------------------------------------------------------------------
// One message as held in the receive queue: the payload bytes plus who sent
// it, what kind of message it is, and the epoch it belongs to.
struct msg_data
{
    std::shared_ptr<std::vector<char> > data;
    unsigned long sender_id;
    char msg_type;
    dlib::uint64 epoch;

    // A fresh object has no payload, sender 0xFFFFFFFF, type -1 and epoch 0.
    msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {}

    // Returns the payload bytes as a std::string, or "" when there is no
    // payload or it is empty.
    std::string data_to_string() const
    {
        if (!data || data->size() == 0)
            return "";

        const std::vector<char>& bytes = *data;
        return std::string(&bytes[0], bytes.size());
    }
};
// ------------------------------------------------------------------------------------
class thread_safe_message_queue : noncopyable
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This is a simple message queue for msg_data objects.  Note that it
            has the special property that, while messages will generally leave
            the queue in the order they are inserted, any message with a smaller
            epoch value will always be popped out first.  But for all messages
            with equal epoch values the queue functions as a normal FIFO queue.

            Every member function takes class_mutex, and the pop() functions
            block on sig until a suitable message arrives or disable() is
            called.
    !*/
private:
    // A queued message together with the order in which it was inserted.
    struct msg_wrap
    {
        msg_wrap(
            const msg_data& data_,
            const dlib::uint64& sequence_number_
        ) : data(data_), sequence_number(sequence_number_) {}

        msg_wrap() : sequence_number(0){}

        msg_data data;
        dlib::uint64 sequence_number;

        // Make it so that when msg_wrap objects are in a std::priority_queue,
        // messages with a smaller epoch number always come first.  Then, within an
        // epoch, messages are ordered by their sequence number (so smaller first
        // there as well).
        //
        // std::priority_queue pops the "largest" element, hence the inverted
        // comparisons.  Both comparisons are strict so that an object never
        // compares less than itself, as a strict weak ordering requires.
        bool operator<(const msg_wrap& item) const
        {
            if (data.epoch != item.data.epoch)
                return data.epoch > item.data.epoch;
            return sequence_number > item.sequence_number;
        }
    };

public:
    thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {}

    ~thread_safe_message_queue()
    {
        disable();
    }

    // Marks the queue as disabled and wakes every thread blocked in pop() so
    // that it returns false.
    void disable()
    {
        auto_mutex lock(class_mutex);
        disabled = true;
        sig.broadcast();
    }

    // Number of messages currently queued.
    unsigned long size() const
    {
        auto_mutex lock(class_mutex);
        return data.size();
    }

    // Queues a copy of item, tagged with the next sequence number, and then
    // clears item.data.
    void push_and_consume( msg_data& item)
    {
        auto_mutex lock(class_mutex);
        data.push(msg_wrap(item, next_seq_num++));
        // do this here so that we don't have to worry about different threads touching the shared_ptr.
        item.data.reset();
        sig.signal();
    }

    bool pop (
        msg_data& item
    )
    /*!
        ensures
            - if (this function returns true) then
                - #item == the next thing from the queue
            - else
                - this object is disabled
    !*/
    {
        auto_mutex lock(class_mutex);
        while (data.size() == 0 && !disabled)
            sig.wait();

        // A disabled queue reports failure even if messages are still queued.
        if (disabled)
            return false;

        item = data.top().data;
        data.pop();
        return true;
    }

    bool pop (
        msg_data& item,
        const dlib::uint64& max_epoch
    )
    /*!
        ensures
            - if (this function returns true) then
                - #item == the next thing from the queue that has an epoch <= max_epoch
            - else
                - this object is disabled
    !*/
    {
        auto_mutex lock(class_mutex);
        // The top of the queue always has the smallest epoch, so if it is too
        // new then so is everything behind it.
        while ((data.size() == 0 || data.top().data.epoch > max_epoch) && !disabled)
            sig.wait();

        if (disabled)
            return false;

        item = data.top().data;
        data.pop();
        return true;
    }

private:
    std::priority_queue<msg_wrap> data;
    // class_mutex must be declared before sig, which is constructed from it.
    dlib::mutex class_mutex;
    dlib::signaler sig;
    bool disabled;
    dlib::uint64 next_seq_num;
};
}
// ----------------------------------------------------------------------------------------
// Handle through which a BSP processing node exchanges messages with the other
// nodes of the job.  The constructors are private: instances are only created
// by the bsp_connect() and bsp_listen_dynamic_port() friend functions declared
// at the bottom of the class.
class bsp_context : noncopyable
{
public:
// Serializes item into a byte buffer and hands it to send_data() for delivery
// to the node target_node_id.
template <typename T>
void send(
const T& item,
unsigned long target_node_id
)
{
// make sure requires clause is not broken
DLIB_CASSERT(target_node_id < number_of_nodes() &&
target_node_id != node_id(),
"\t void bsp_context::send()"
<< "\n\t Invalid arguments were given to this function."
<< "\n\t target_node_id: " << target_node_id
<< "\n\t node_id(): " << node_id()
<< "\n\t number_of_nodes(): " << number_of_nodes()
<< "\n\t this: " << this
);
std::vector<char> buf;
vectorstream sout(buf);
serialize(item, sout);
send_data(buf, target_node_id);
}
// Serializes item once and sends the same bytes to every node except this one.
template <typename T>
void broadcast (
const T& item
)
{
std::vector<char> buf;
vectorstream sout(buf);
serialize(item, sout);
for (unsigned long i = 0; i < number_of_nodes(); ++i)
{
// Don't send to yourself.
if (i == node_id())
continue;
send_data(buf, i);
}
}
// Id of this processing node.
unsigned long node_id (
) const { return _node_id; }
// Number of entries in _cons plus one.  NOTE(review): this assumes _cons holds
// one connection per other node, so the +1 accounts for this node -- confirm.
unsigned long number_of_nodes (
) const { return _cons.size()+1; }
// Receive overload that takes no message: it throws socket_error if
// receive_data() reports that a message was received.
void receive (
)
{
unsigned long id;
std::shared_ptr<std::vector<char> > temp;
if (receive_data(temp,id))
throw dlib::socket_error("Call to bsp_context::receive() got an unexpected message.");
}
// Like try_receive(item) but throws socket_error instead of returning false.
template <typename T>
void receive (
T& item
)
{
if(!try_receive(item))
throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked.");
}
// Same as the two argument try_receive() but discards the sender's id.
template <typename T>
bool try_receive (
T& item
)
{
unsigned long sending_node_id;
return try_receive(item, sending_node_id);
}
// Like try_receive(item, sending_node_id) but throws socket_error instead of
// returning false.
template <typename T>
void receive (
T& item,
unsigned long& sending_node_id
)
{
if(!try_receive(item, sending_node_id))
throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked.");
}
// Asks receive_data() for the next message.  When there is one, deserializes
// its bytes into item and returns true; otherwise returns false.
template <typename T>
bool try_receive (
T& item,
unsigned long& sending_node_id
)
{
std::shared_ptr<std::vector<char> > temp;
if (receive_data(temp, sending_node_id))
{
vectorstream sin(*temp);
deserialize(item, sin);
// Any bytes left over mean the message was not written as a T.
if (sin.peek() != EOF)
throw serialization_error("deserialize() did not consume all bytes produced by serialize(). "
"This probably means you are calling a receive method with a different type "
"of object than the one which was sent.");
return true;
}
else
{
return false;
}
}
~bsp_context();
private:
// Declared but private; see the constructor below for how instances are made.
bsp_context();
// Used by the friend functions below, which pass the id of this node and the
// map of connections they have just established.
bsp_context(
unsigned long node_id_,
impl1::map_id_to_con& cons_
);
void close_all_connections_gracefully();
/*!
ensures
- closes all the connections to other nodes and lets them know that
we are terminating normally rather than as the result of some kind
of error.
!*/
// Returns true when a message was received, with its bytes in item and its
// sender in sending_node_id (this is how the receive methods above use it).
bool receive_data (
std::shared_ptr<std::vector<char> >& item,
unsigned long& sending_node_id
);
void notify_control_node (
char val
);
void broadcast_byte (
char val
);
void send_data(
const std::vector<char>& item,
unsigned long target_node_id
);
/*!
requires
- target_node_id < number_of_nodes()
- target_node_id != node_id()
ensures
- sends a copy of item to the node with the given id.
!*/
unsigned long outstanding_messages;
unsigned long num_waiting_nodes;
unsigned long num_terminated_nodes;
dlib::uint64 current_epoch;
impl1::thread_safe_message_queue msg_buffer;
// Reference member: the map itself is a local variable of the friend function
// that constructs this object.
impl1::map_id_to_con& _cons;
const unsigned long _node_id;
array<std::unique_ptr<thread_function> > threads;
// -----------------------------------
// The bsp_connect() overloads need access to the private constructor and to
// close_all_connections_gracefully().
template <
typename funct_type
>
friend void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct
);
template <
typename funct_type,
typename ARG1
>
friend void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1
);
template <
typename funct_type,
typename ARG1,
typename ARG2
>
friend void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2
);
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3
>
friend void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3
);
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3,
typename ARG4
>
friend void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3,
ARG4 arg4
);
// -----------------------------------
// The bsp_listen_dynamic_port() overloads are friends for the same reason.
template <
typename port_notify_function_type,
typename funct_type
>
friend void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct
);
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1
>
friend void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1
);
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2
>
friend void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2
);
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3
>
friend void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3
);
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3,
typename ARG4
>
friend void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3,
ARG4 arg4
);
// -----------------------------------
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename funct_type
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct
)
{
impl1::map_id_to_con cons;
const unsigned long node_id = 0;
connect_all(cons, hosts, node_id);
send_out_connection_orders(cons, hosts);
bsp_context obj(node_id, cons);
funct(obj);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1
)
{
impl1::map_id_to_con cons;
const unsigned long node_id = 0;
connect_all(cons, hosts, node_id);
send_out_connection_orders(cons, hosts);
bsp_context obj(node_id, cons);
funct(obj,arg1);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2
)
{
impl1::map_id_to_con cons;
const unsigned long node_id = 0;
connect_all(cons, hosts, node_id);
send_out_connection_orders(cons, hosts);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3
)
{
impl1::map_id_to_con cons;
const unsigned long node_id = 0;
connect_all(cons, hosts, node_id);
send_out_connection_orders(cons, hosts);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2,arg3);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3,
typename ARG4
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3,
ARG4 arg4
)
{
impl1::map_id_to_con cons;
const unsigned long node_id = 0;
connect_all(cons, hosts, node_id);
send_out_connection_orders(cons, hosts);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2,arg3,arg4);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Listens on a fixed, non-zero port and runs funct(context) when the job
// starts.  Forwards to bsp_listen_dynamic_port() with a no-op port callback.
template <
    typename funct_type
    >
void bsp_listen (
    unsigned short listening_port,
    funct_type funct
)
{
    // make sure requires clause is not broken
    DLIB_CASSERT(listening_port != 0,
        "\t void bsp_listen()"
        << "\n\t Invalid arguments were given to this function."
        );

    // The caller already knows the port, so nothing needs to be notified.
    bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct);
}
// ----------------------------------------------------------------------------------------
// Listens on a fixed, non-zero port and runs funct(context,arg1) when the job
// starts.  Forwards to bsp_listen_dynamic_port() with a no-op port callback.
template <
    typename funct_type,
    typename ARG1
    >
void bsp_listen (
    unsigned short listening_port,
    funct_type funct,
    ARG1 arg1
)
{
    // make sure requires clause is not broken
    DLIB_CASSERT(listening_port != 0,
        "\t void bsp_listen()"
        << "\n\t Invalid arguments were given to this function."
        );

    // The caller already knows the port, so nothing needs to be notified.
    bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1);
}
// ----------------------------------------------------------------------------------------
// Listens on a fixed, non-zero port and runs funct(context,arg1,arg2) when the
// job starts.  Forwards to bsp_listen_dynamic_port() with a no-op port callback.
template <
    typename funct_type,
    typename ARG1,
    typename ARG2
    >
void bsp_listen (
    unsigned short listening_port,
    funct_type funct,
    ARG1 arg1,
    ARG2 arg2
)
{
    // make sure requires clause is not broken
    DLIB_CASSERT(listening_port != 0,
        "\t void bsp_listen()"
        << "\n\t Invalid arguments were given to this function."
        );

    // The caller already knows the port, so nothing needs to be notified.
    bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2);
}
// ----------------------------------------------------------------------------------------
// Listens on a fixed, non-zero port and runs funct(context,arg1,arg2,arg3)
// when the job starts.  Forwards to bsp_listen_dynamic_port() with a no-op
// port callback.
template <
    typename funct_type,
    typename ARG1,
    typename ARG2,
    typename ARG3
    >
void bsp_listen (
    unsigned short listening_port,
    funct_type funct,
    ARG1 arg1,
    ARG2 arg2,
    ARG3 arg3
)
{
    // make sure requires clause is not broken
    DLIB_CASSERT(listening_port != 0,
        "\t void bsp_listen()"
        << "\n\t Invalid arguments were given to this function."
        );

    // The caller already knows the port, so nothing needs to be notified.
    bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3);
}
// ----------------------------------------------------------------------------------------
// Listens on a fixed, non-zero port and runs
// funct(context,arg1,arg2,arg3,arg4) when the job starts.  Forwards to
// bsp_listen_dynamic_port() with a no-op port callback.
template <
    typename funct_type,
    typename ARG1,
    typename ARG2,
    typename ARG3,
    typename ARG4
    >
void bsp_listen (
    unsigned short listening_port,
    funct_type funct,
    ARG1 arg1,
    ARG2 arg2,
    ARG3 arg3,
    ARG4 arg4
)
{
    // make sure requires clause is not broken
    DLIB_CASSERT(listening_port != 0,
        "\t void bsp_listen()"
        << "\n\t Invalid arguments were given to this function."
        );

    // The caller already knows the port, so nothing needs to be notified.
    bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3, arg4);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct
)
{
impl1::map_id_to_con cons;
unsigned long node_id;
listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
bsp_context obj(node_id, cons);
funct(obj);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1
)
{
impl1::map_id_to_con cons;
unsigned long node_id;
listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
bsp_context obj(node_id, cons);
funct(obj,arg1);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2
)
{
impl1::map_id_to_con cons;
unsigned long node_id;
listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3
)
{
impl1::map_id_to_con cons;
unsigned long node_id;
listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2,arg3);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3,
typename ARG4
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3,
ARG4 arg4
)
{
impl1::map_id_to_con cons;
unsigned long node_id;
listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2,arg3,arg4);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
}
#ifdef NO_MAKEFILE
#include "bsp.cpp"
#endif
#endif // DLIB_BsP_Hh_
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_BsP_ABSTRACT_Hh_
#ifdef DLIB_BsP_ABSTRACT_Hh_
#include "../noncopyable.h"
#include "../sockets/sockets_extensions_abstract.h"
#include <vector>
namespace dlib
{
// ----------------------------------------------------------------------------------------
class bsp_context : noncopyable
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a tool used to implement algorithms using the Bulk Synchronous
Parallel (BSP) computing model. A BSP algorithm is composed of a number of
processing nodes, each executing in parallel. The general flow of
execution in each processing node is the following:
1. Do work locally on some data.
2. Send some messages to other nodes.
3. Receive messages from other nodes.
4. Go to step 1 or terminate if complete.
To do this, each processing node needs an API used to send and receive
messages. This API is implemented by the bsp_context object which provides
these services to a BSP node.
Note that BSP processing nodes are spawned using the bsp_connect() and
bsp_listen() routines defined at the bottom of this file. For example, to
start a BSP algorithm consisting of N processing nodes, you would make N-1
calls to bsp_listen() and one call to bsp_connect(). The call to
bsp_connect() then initiates the computation on all nodes.
Finally, note that there is no explicit barrier synchronization function
you call at the end of step 3. Instead, you can simply call a method such
as try_receive() until it returns false. That is, the bsp_context's
receive methods incorporate a barrier synchronization that happens once all
the BSP nodes are blocked on receive calls and there are no more messages
in flight.
THREAD SAFETY
This object is not thread-safe. In particular, you should only ever have
one thread that works with an instance of this object. This means that,
for example, you should not spawn sub-threads from within a BSP processing
node and have them invoke methods on this object. Instead, you should only
invoke this object's methods from within the BSP processing node's main
thread (i.e. the thread that executes the user supplied function funct()).
!*/
public:
template <typename T>
void send(
const T& item,
unsigned long target_node_id
);
/*!
requires
- item is serializable
- target_node_id < number_of_nodes()
- target_node_id != node_id()
ensures
- sends a copy of item to the node with the given id.
throws
- dlib::socket_error:
This exception is thrown if there is an error which prevents us from
delivering the message to the given node. One way this might happen is
if the target node has already terminated its execution or has lost
network connectivity.
!*/
template <typename T>
void broadcast (
const T& item
);
/*!
requires
- item is serializable
ensures
- sends a copy of item to all other processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents us from
delivering a message to one of the other nodes. This might happen, for
example, if one of the nodes has terminated its execution or has lost
network connectivity.
!*/
unsigned long node_id (
) const;
/*!
ensures
- Returns the id of the current processing node. That is,
returns a number N such that:
- N < number_of_nodes()
- N == the node id of the processing node that called node_id(). This
is a number that uniquely identifies the processing node.
!*/
unsigned long number_of_nodes (
) const;
/*!
ensures
- returns the number of processing nodes participating in the BSP
computation.
!*/
template <typename T>
bool try_receive (
T& item
);
/*!
requires
- item is serializable
ensures
- if (this function returns true) then
- #item == the next message which was sent to the calling processing
node.
- else
- The following must have been true for this function to return false:
- All other nodes were blocked on calls to receive(),
try_receive(), or have terminated.
- There were not any messages in flight between any nodes.
- That is, if all the nodes had continued to block on receive
methods then they all would have blocked forever. Therefore,
this function only returns false once there are no more messages
to process by any node and there is no possibility of more being
generated until control is returned to the callers of receive
methods.
- When one BSP node's receive method returns because of the above
conditions then all of them will also return. That is, it is NOT the
case that just a subset of BSP nodes unblock. Moreover, they all
unblock at the same time.
throws
- dlib::socket_error:
This exception is thrown if some error occurs which prevents us from
communicating with other processing nodes.
- dlib::serialization_error or any exception thrown by the global
deserialize(T) routine:
This is thrown if there is a problem in deserialize(). This might
happen if the message sent doesn't match the type T expected by
try_receive().
!*/
template <typename T>
void receive (
T& item
);
/*!
requires
- item is serializable
ensures
- #item == the next message which was sent to the calling processing
node.
- This function is just a wrapper around try_receive() that throws an
exception if a message is not received (i.e. if try_receive() returns
false).
throws
- dlib::socket_error:
This exception is thrown if some error occurs which prevents us from
communicating with other processing nodes or if there was not a message
to receive.
- dlib::serialization_error or any exception thrown by the global
deserialize(T) routine:
This is thrown if there is a problem in deserialize(). This might
happen if the message sent doesn't match the type T expected by
receive().
!*/
template <typename T>
bool try_receive (
T& item,
unsigned long& sending_node_id
);
/*!
requires
- item is serializable
ensures
- if (this function returns true) then
- #item == the next message which was sent to the calling processing
node.
- #sending_node_id == the node id of the node that sent this message.
- #sending_node_id < number_of_nodes()
- else
- The following must have been true for this function to return false:
- All other nodes were blocked on calls to receive(),
try_receive(), or have terminated.
- There were not any messages in flight between any nodes.
- That is, if all the nodes had continued to block on receive
methods then they all would have blocked forever. Therefore,
this function only returns false once there are no more messages
to process by any node and there is no possibility of more being
generated until control is returned to the callers of receive
methods.
- When one BSP node's receive method returns because of the above
conditions then all of them will also return. That is, it is NOT the
case that just a subset of BSP nodes unblock. Moreover, they all
unblock at the same time.
throws
- dlib::socket_error:
This exception is thrown if some error occurs which prevents us from
communicating with other processing nodes.
- dlib::serialization_error or any exception thrown by the global
deserialize(T) routine:
This is thrown if there is a problem in deserialize(). This might
happen if the message sent doesn't match the type T expected by
try_receive().
!*/
template <typename T>
void receive (
T& item,
unsigned long& sending_node_id
);
/*!
requires
- item is serializable
ensures
- #item == the next message which was sent to the calling processing node.
- #sending_node_id == the node id of the node that sent this message.
- #sending_node_id < number_of_nodes()
- This function is just a wrapper around try_receive() that throws an
exception if a message is not received (i.e. if try_receive() returns
false).
throws
- dlib::socket_error:
This exception is thrown if some error occurs which prevents us from
communicating with other processing nodes or if there was not a message
to receive.
- dlib::serialization_error or any exception thrown by the global
deserialize(T) routine:
This is thrown if there is a problem in deserialize(). This might
happen if the message sent doesn't match the type T expected by
receive().
!*/
void receive (
);
/*!
ensures
- Waits for the following to all be true:
- All other nodes were blocked on calls to receive(), try_receive(), or
have terminated.
- There are not any messages in flight between any nodes.
- That is, if all the nodes had continued to block on receive methods
then they all would have blocked forever. Therefore, this function
only returns once there are no more messages to process by any node
and there is no possibility of more being generated until control is
returned to the callers of receive methods.
- When one BSP node's receive method returns because of the above
conditions then all of them will also return. That is, it is NOT the
case that just a subset of BSP nodes unblock. Moreover, they all unblock
at the same time.
throws
- dlib::socket_error:
This exception is thrown if some error occurs which prevents us from
communicating with other processing nodes or if a message is received
before this function would otherwise return.
!*/
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename funct_type
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct
);
/*!
requires
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT) must be a valid expression
(i.e. funct must be a function or function object)
ensures
- This function spawns a BSP job consisting of hosts.size()+1 processing nodes.
- The processing node with a node ID of 0 will run locally on the machine
calling bsp_connect(). In particular, this node will execute funct(CONTEXT),
which is expected to carry out this node's portion of the BSP computation.
- The other processing nodes are executed on the hosts indicated by the input
argument. In particular, this function interprets hosts as a list of addresses
identifying machines running the bsp_listen() or bsp_listen_dynamic_port()
routines.
- This call to bsp_connect() blocks until the BSP computation has completed on
all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_connect().
!*/
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1
);
/*!
requires
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1) must be a valid expression
(i.e. funct must be a function or function object)
ensures
- This function spawns a BSP job consisting of hosts.size()+1 processing nodes.
- The processing node with a node ID of 0 will run locally on the machine
calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1),
which is expected to carry out this node's portion of the BSP computation.
- The other processing nodes are executed on the hosts indicated by the input
argument. In particular, this function interprets hosts as a list of addresses
identifying machines running the bsp_listen() or bsp_listen_dynamic_port()
routines.
- This call to bsp_connect() blocks until the BSP computation has completed on
all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_connect().
!*/
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2
);
/*!
requires
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1,arg2) must be a valid expression
(i.e. funct must be a function or function object)
ensures
- This function spawns a BSP job consisting of hosts.size()+1 processing nodes.
- The processing node with a node ID of 0 will run locally on the machine
calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2),
which is expected to carry out this node's portion of the BSP computation.
- The other processing nodes are executed on the hosts indicated by the input
argument. In particular, this function interprets hosts as a list of addresses
identifying machines running the bsp_listen() or bsp_listen_dynamic_port()
routines.
- This call to bsp_connect() blocks until the BSP computation has completed on
all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_connect().
!*/
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3
);
/*!
requires
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1,arg2,arg3) must be a valid expression
(i.e. funct must be a function or function object)
ensures
- This function spawns a BSP job consisting of hosts.size()+1 processing nodes.
- The processing node with a node ID of 0 will run locally on the machine
calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3),
which is expected to carry out this node's portion of the BSP computation.
- The other processing nodes are executed on the hosts indicated by the input
argument. In particular, this function interprets hosts as a list of addresses
identifying machines running the bsp_listen() or bsp_listen_dynamic_port()
routines.
- This call to bsp_connect() blocks until the BSP computation has completed on
all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_connect().
!*/
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3,
typename ARG4
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3,
ARG4 arg4
);
/*!
requires
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression
(i.e. funct must be a function or function object)
ensures
- This function spawns a BSP job consisting of hosts.size()+1 processing nodes.
- The processing node with a node ID of 0 will run locally on the machine
calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3,arg4),
which is expected to carry out this node's portion of the BSP computation.
- The other processing nodes are executed on the hosts indicated by the input
argument. In particular, this function interprets hosts as a list of addresses
identifying machines running the bsp_listen() or bsp_listen_dynamic_port()
routines.
- This call to bsp_connect() blocks until the BSP computation has completed on
all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_connect().
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename funct_type
>
void bsp_listen (
unsigned short listening_port,
funct_type funct
);
/*!
requires
- listening_port != 0
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT) must be a valid expression
(i.e. funct must be a function or function object)
ensures
- This function listens for a connection from the bsp_connect() routine. Once
this connection is established, funct(CONTEXT) will be executed and it will
then be able to participate in the BSP computation as one of the processing
nodes.
- This function will listen on TCP port listening_port for a connection from
bsp_connect(). Once the connection is established, it will close the
listening port so it is free for use by other applications. The connection
and BSP computation will continue uninterrupted.
- This call to bsp_listen() blocks until the BSP computation has completed on
all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_listen().
!*/
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1
>
void bsp_listen (
unsigned short listening_port,
funct_type funct,
ARG1 arg1
);
/*!
requires
- listening_port != 0
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1) must be a valid expression
(i.e. funct must be a function or function object)
ensures
- This function listens for a connection from the bsp_connect() routine. Once
this connection is established, funct(CONTEXT,arg1) will be executed and it will
then be able to participate in the BSP computation as one of the processing
nodes.
- This function will listen on TCP port listening_port for a connection from
bsp_connect(). Once the connection is established, it will close the
listening port so it is free for use by other applications. The connection
and BSP computation will continue uninterrupted.
- This call to bsp_listen() blocks until the BSP computation has completed on
all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_listen().
!*/
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2
>
void bsp_listen (
unsigned short listening_port,
funct_type funct,
ARG1 arg1,
ARG2 arg2
);
/*!
requires
- listening_port != 0
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1,arg2) must be a valid expression
(i.e. funct must be a function or function object)
ensures
- This function listens for a connection from the bsp_connect() routine. Once
this connection is established, funct(CONTEXT,arg1,arg2) will be executed and
it will then be able to participate in the BSP computation as one of the
processing nodes.
- This function will listen on TCP port listening_port for a connection from
bsp_connect(). Once the connection is established, it will close the
listening port so it is free for use by other applications. The connection
and BSP computation will continue uninterrupted.
- This call to bsp_listen() blocks until the BSP computation has completed on
all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_listen().
!*/
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3
>
void bsp_listen (
unsigned short listening_port,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3
);
/*!
requires
- listening_port != 0
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1,arg2,arg3) must be a valid expression
(i.e. funct must be a function or function object)
ensures
- This function listens for a connection from the bsp_connect() routine. Once
this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be
executed and it will then be able to participate in the BSP computation as
one of the processing nodes.
- This function will listen on TCP port listening_port for a connection from
bsp_connect(). Once the connection is established, it will close the
listening port so it is free for use by other applications. The connection
and BSP computation will continue uninterrupted.
- This call to bsp_listen() blocks until the BSP computation has completed on
all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_listen().
!*/
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3,
typename ARG4
>
void bsp_listen (
unsigned short listening_port,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3,
ARG4 arg4
);
/*!
requires
- listening_port != 0
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression
(i.e. funct must be a function or function object)
ensures
- This function listens for a connection from the bsp_connect() routine. Once
this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be
executed and it will then be able to participate in the BSP computation as
one of the processing nodes.
- This function will listen on TCP port listening_port for a connection from
bsp_connect(). Once the connection is established, it will close the
listening port so it is free for use by other applications. The connection
and BSP computation will continue uninterrupted.
- This call to bsp_listen() blocks until the BSP computation has completed on
all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_listen().
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct
);
/*!
requires
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT) must be a valid expression
(i.e. funct must be a function or function object)
- port_notify_function((unsigned short) 1234) must be a valid expression
(i.e. port_notify_function() must be a function or function object taking an
unsigned short)
ensures
- This function listens for a connection from the bsp_connect() routine. Once
this connection is established, funct(CONTEXT) will be executed and it will
then be able to participate in the BSP computation as one of the processing
nodes.
- if (listening_port != 0) then
- This function will listen on TCP port listening_port for a connection
from bsp_connect().
- else
- An available TCP port number is automatically selected and this function
will listen on it for a connection from bsp_connect().
- Once a listening port is opened, port_notify_function() is called with the
port number used. This provides a mechanism to find out what listening port
has been used if it is automatically selected. It also allows you to find
out when the routine has begun listening for an incoming connection from
bsp_connect().
- Once a connection is established, we will close the listening port so it is
free for use by other applications. The connection and BSP computation will
continue uninterrupted.
- This call to bsp_listen_dynamic_port() blocks until the BSP computation has
completed on all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_listen_dynamic_port().
!*/
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1
);
/*!
requires
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1) must be a valid expression
(i.e. funct must be a function or function object)
- port_notify_function((unsigned short) 1234) must be a valid expression
(i.e. port_notify_function() must be a function or function object taking an
unsigned short)
ensures
- This function listens for a connection from the bsp_connect() routine. Once
this connection is established, funct(CONTEXT,arg1) will be executed and it
will then be able to participate in the BSP computation as one of the
processing nodes.
- if (listening_port != 0) then
- This function will listen on TCP port listening_port for a connection
from bsp_connect().
- else
- An available TCP port number is automatically selected and this function
will listen on it for a connection from bsp_connect().
- Once a listening port is opened, port_notify_function() is called with the
port number used. This provides a mechanism to find out what listening port
has been used if it is automatically selected. It also allows you to find
out when the routine has begun listening for an incoming connection from
bsp_connect().
- Once a connection is established, we will close the listening port so it is
free for use by other applications. The connection and BSP computation will
continue uninterrupted.
- This call to bsp_listen_dynamic_port() blocks until the BSP computation has
completed on all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_listen_dynamic_port().
!*/
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2
);
/*!
requires
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1,arg2) must be a valid expression
(i.e. funct must be a function or function object)
- port_notify_function((unsigned short) 1234) must be a valid expression
(i.e. port_notify_function() must be a function or function object taking an
unsigned short)
ensures
- This function listens for a connection from the bsp_connect() routine. Once
this connection is established, funct(CONTEXT,arg1,arg2) will be executed and
it will then be able to participate in the BSP computation as one of the
processing nodes.
- if (listening_port != 0) then
- This function will listen on TCP port listening_port for a connection
from bsp_connect().
- else
- An available TCP port number is automatically selected and this function
will listen on it for a connection from bsp_connect().
- Once a listening port is opened, port_notify_function() is called with the
port number used. This provides a mechanism to find out what listening port
has been used if it is automatically selected. It also allows you to find
out when the routine has begun listening for an incoming connection from
bsp_connect().
- Once a connection is established, we will close the listening port so it is
free for use by other applications. The connection and BSP computation will
continue uninterrupted.
- This call to bsp_listen_dynamic_port() blocks until the BSP computation has
completed on all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_listen_dynamic_port().
!*/
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3
);
/*!
requires
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1,arg2,arg3) must be a valid expression
(i.e. funct must be a function or function object)
- port_notify_function((unsigned short) 1234) must be a valid expression
(i.e. port_notify_function() must be a function or function object taking an
unsigned short)
ensures
- This function listens for a connection from the bsp_connect() routine. Once
this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be
executed and it will then be able to participate in the BSP computation as
one of the processing nodes.
- if (listening_port != 0) then
- This function will listen on TCP port listening_port for a connection
from bsp_connect().
- else
- An available TCP port number is automatically selected and this function
will listen on it for a connection from bsp_connect().
- Once a listening port is opened, port_notify_function() is called with the
port number used. This provides a mechanism to find out what listening port
has been used if it is automatically selected. It also allows you to find
out when the routine has begun listening for an incoming connection from
bsp_connect().
- Once a connection is established, we will close the listening port so it is
free for use by other applications. The connection and BSP computation will
continue uninterrupted.
- This call to bsp_listen_dynamic_port() blocks until the BSP computation has
completed on all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_listen_dynamic_port().
!*/
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3,
typename ARG4
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3,
ARG4 arg4
);
/*!
requires
- let CONTEXT be an instance of a bsp_context object. Then:
- funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression
(i.e. funct must be a function or function object)
- port_notify_function((unsigned short) 1234) must be a valid expression
(i.e. port_notify_function() must be a function or function object taking an
unsigned short)
ensures
- This function listens for a connection from the bsp_connect() routine. Once
this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be
executed and it will then be able to participate in the BSP computation as
one of the processing nodes.
- if (listening_port != 0) then
- This function will listen on TCP port listening_port for a connection
from bsp_connect().
- else
- An available TCP port number is automatically selected and this function
will listen on it for a connection from bsp_connect().
- Once a listening port is opened, port_notify_function() is called with the
port number used. This provides a mechanism to find out what listening port
has been used if it is automatically selected. It also allows you to find
out when the routine has begun listening for an incoming connection from
bsp_connect().
- Once a connection is established, we will close the listening port so it is
free for use by other applications. The connection and BSP computation will
continue uninterrupted.
- This call to bsp_listen_dynamic_port() blocks until the BSP computation has
completed on all processing nodes.
throws
- dlib::socket_error
This exception is thrown if there is an error which prevents the BSP
job from executing.
- Any exception thrown by funct() will be propagated out of this call to
bsp_listen_dynamic_port().
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BsP_ABSTRACT_Hh_
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BYTE_ORDEREr_
#define DLIB_BYTE_ORDEREr_
#include "byte_orderer/byte_orderer_kernel_1.h"
#endif // DLIB_BYTE_ORDEREr_
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BYTE_ORDEREr_KERNEL_1_
#define DLIB_BYTE_ORDEREr_KERNEL_1_
#include "byte_orderer_kernel_abstract.h"
#include "../algs.h"
#include "../assert.h"
namespace dlib
{
class byte_orderer
{
    /*!
        INITIAL VALUE
            - little_endian == true if and only if this machine is little endian

        CONVENTION
            - host_is_big_endian() == !little_endian
            - host_is_little_endian() == little_endian
            - little_endian == true if and only if this machine is little endian
    !*/

public:

    // Alias retained so code written against older versions of dlib keeps compiling.
    typedef byte_orderer kernel_1a;

    byte_orderer (
    )
    {
        // Per the original author's note: this is not expected to ever fail, and
        // if it does it indicates chars are not 8 bits wide on this system, which
        // this object cannot cope with.
        COMPILE_TIME_ASSERT(sizeof(short) >= 2);

        // Store the value 1 in a multi-byte integer and inspect the byte at the
        // lowest address.  It holds the 1 only when the least significant byte
        // comes first, i.e. on a little endian host.
        unsigned long probe = 1;
        const unsigned char* const first_byte = reinterpret_cast<unsigned char*>(&probe);
        little_endian = (*first_byte == 1);
    }

    virtual ~byte_orderer (
    )
    {
    }

    bool host_is_big_endian (
    ) const
    {
        return !little_endian;
    }

    bool host_is_little_endian (
    ) const
    {
        return little_endian;
    }

    // Network byte order is handled exactly like big endian byte order below:
    // bytes are reversed only when the host is little endian.
    template <
        typename T
        >
    inline void host_to_network (
        T& item
    ) const
    {
        if (host_is_little_endian())
            flip(item);
    }

    template <
        typename T
        >
    inline void network_to_host (
        T& item
    ) const
    {
        if (host_is_little_endian())
            flip(item);
    }

    template <
        typename T
        >
    void host_to_big (
        T& item
    ) const
    {
        if (host_is_little_endian())
            flip(item);
    }

    template <
        typename T
        >
    void big_to_host (
        T& item
    ) const
    {
        if (host_is_little_endian())
            flip(item);
    }

    // Little endian conversions are the mirror image: bytes are reversed only
    // when the host is big endian.
    template <
        typename T
        >
    void host_to_little (
        T& item
    ) const
    {
        if (host_is_big_endian())
            flip(item);
    }

    template <
        typename T
        >
    void little_to_host (
        T& item
    ) const
    {
        if (host_is_big_endian())
            flip(item);
    }

private:

    template <
        typename T,
        size_t size
        >
    inline void flip (
        T (&array)[size]
    ) const
    /*!
        ensures
            - reverses the byte ordering of each element of array, one element
              at a time (the order of the elements themselves is unchanged)
    !*/
    {
        for (size_t idx = 0; idx != size; ++idx)
            flip(array[idx]);
    }

    template <
        typename T
        >
    inline void flip (
        T& item
    ) const
    /*!
        ensures
            - reverses the byte ordering in item
    !*/
    {
        DLIB_ASSERT_HAS_STANDARD_LAYOUT(T);

        // A failure here means T is bigger than a long double.  According to the
        // original author this most likely means the object is being used
        // incorrectly; the check may be relaxed if a legitimate use case exists.
        COMPILE_TIME_ASSERT(sizeof(T) <= sizeof(long double));

        // A failure here means T is a pointer type.  Byte swapping a pointer is
        // meaningless because a pointer has no meaning outside its own process,
        // so the caller most likely forgot to dereference it before calling.
        COMPILE_TIME_ASSERT(is_pointer_type<T>::value == false);

        // Build the byte-reversed image in a scratch object, then assign it back.
        T swapped;
        const unsigned char* const src = reinterpret_cast<unsigned char*>(&item);
        unsigned char* const dst = reinterpret_cast<unsigned char*>(&swapped);

        size_t out = sizeof(T);
        for (size_t in = 0; in < sizeof(T); ++in)
            dst[--out] = src[in];

        item = swapped;
    }

    bool little_endian;
};
// Explicit specializations of the single-item byte_orderer::flip() for the three
// character types (char, unsigned char, signed char).  Each one has an empty
// body, so flipping a character leaves it untouched.
template <> inline void byte_orderer::flip<char> ( char& ) const {}
template <> inline void byte_orderer::flip<unsigned char> ( unsigned char& ) const {}
template <> inline void byte_orderer::flip<signed char> ( signed char& ) const {}
}
#endif // DLIB_BYTE_ORDEREr_KERNEL_1_
// Copyright (C) 2006 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_BYTE_ORDEREr_ABSTRACT_
#ifdef DLIB_BYTE_ORDEREr_ABSTRACT_
#include "../algs.h"
namespace dlib
{
class byte_orderer
{
/*!
INITIAL VALUE
This object has no state.
WHAT THIS OBJECT REPRESENTS
This object simply provides a mechanism to convert data from a
host machine's own byte ordering to big or little endian and to
also do the reverse.
It also provides a pair of functions to convert to/from network byte
order where network byte order is big endian byte order. This pair of
functions does the exact same thing as the host_to_big() and big_to_host()
functions and is provided simply so that client code can use the most
self documenting name appropriate.
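Also note that this object is capable of correctly flipping the contents
of built-in arrays whose size is known at compile time (e.g. arrays
declared on the stack). In that case the bytes of each element are
flipped individually. For example, you can say things like:
byte_orderer bo;
int array[10];
bo.host_to_network(array);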
!*/
public:
byte_orderer (
);
/*!
ensures
- #*this is properly initialized
throws
- std::bad_alloc
!*/
virtual ~byte_orderer (
);
/*!
ensures
- any resources associated with *this have been released
!*/
bool host_is_big_endian (
) const;
/*!
ensures
- if (the host computer is a big endian machine) then
- returns true
- else
- returns false
!*/
bool host_is_little_endian (
) const;
/*!
ensures
- if (the host computer is a little endian machine) then
- returns true
- else
- returns false
!*/
template <
typename T
>
void host_to_network (
T& item
) const;
/*!
ensures
- #item == the value of item converted from host byte order
to network byte order.
!*/
template <
typename T
>
void network_to_host (
T& item
) const;
/*!
ensures
- #item == the value of item converted from network byte order
to host byte order.
!*/
template <
typename T
>
void host_to_big (
T& item
) const;
/*!
ensures
- #item == the value of item converted from host byte order
to big endian byte order.
!*/
template <
typename T
>
void big_to_host (
T& item
) const;
/*!
ensures
- #item == the value of item converted from big endian byte order
to host byte order.
!*/
template <
typename T
>
void host_to_little (
T& item
) const;
/*!
ensures
- #item == the value of item converted from host byte order
to little endian byte order.
!*/
template <
typename T
>
void little_to_host (
T& item
) const;
/*!
ensures
- #item == the value of item converted from little endian byte order
to host byte order.
!*/
};
}
#endif // DLIB_BYTE_ORDEREr_ABSTRACT_
#include "dlib_include_path_tutorial.txt"
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CLuSTERING_
#define DLIB_CLuSTERING_
#include "clustering/modularity_clustering.h"
#include "clustering/chinese_whispers.h"
#include "clustering/spectral_cluster.h"
#include "clustering/bottom_up_cluster.h"
#include "svm/kkmeans.h"
#endif // DLIB_CLuSTERING_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BOTTOM_uP_CLUSTER_Hh_
#define DLIB_BOTTOM_uP_CLUSTER_Hh_
#include <queue>
#include <map>
#include "bottom_up_cluster_abstract.h"
#include "../algs.h"
#include "../matrix.h"
#include "../disjoint_subsets.h"
#include "../graph_utils.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace buc_impl
{
    inline void merge_sets (
        matrix<double>& dists,
        unsigned long dest,
        unsigned long src
    )
    /*!
        ensures
            - For every row index r of dists, both #dists(r,dest) and #dists(dest,r)
              are set to the larger of dists(r,dest) and dists(r,src).
            - The row and column belonging to src are otherwise left as they were.
    !*/
    {
        const long num_rows = dists.nr();
        for (long r = 0; r < num_rows; ++r)
        {
            const double farthest = std::max(dists(r,dest), dists(r,src));
            dists(r,dest) = farthest;
            dists(dest,r) = farthest;
        }
    }

    // Ordering functor for std::priority_queue: a is ranked "less" than b when it
    // has the larger distance, so the queue's top() is the pair with the smallest
    // distance.
    struct compare_dist
    {
        bool operator() (
            const sample_pair& a,
            const sample_pair& b
        ) const
        {
            return b.distance() < a.distance();
        }
    };
}
// ----------------------------------------------------------------------------------------
template <
    typename EXP
    >
unsigned long bottom_up_cluster (
    const matrix_exp<EXP>& dists_,
    std::vector<unsigned long>& labels,
    unsigned long min_num_clusters,
    double max_dist = std::numeric_limits<double>::infinity()
)
/*!
    Agglomerative clustering driven by a square matrix of pairwise distances.

    - dists_:           distance matrix; it is copied into a matrix<double> that
                        this routine is free to modify.
    - labels:           output; resized to dists_.nr() and filled with contiguous
                        cluster ids starting at 0.
    - min_num_clusters: merging stops once this many clusters remain (must be > 0).
    - max_dist:         merging also stops as soon as the closest remaining pair
                        of clusters is farther apart than this.
    - returns the number of distinct clusters written into labels.
!*/
{
    matrix<double> dists = matrix_cast<double>(dists_);
    // make sure requires clause is not broken
    DLIB_CASSERT(dists.nr() == dists.nc() && min_num_clusters > 0,
        "\t unsigned long bottom_up_cluster()"
        << "\n\t Invalid inputs were given to this function."
        << "\n\t dists.nr(): " << dists.nr()
        << "\n\t dists.nc(): " << dists.nc()
        << "\n\t min_num_clusters: " << min_num_clusters
        );

    using namespace buc_impl;

    const long num_samples = dists.nr();
    labels.resize(num_samples);
    disjoint_subsets sets;
    sets.set_size(num_samples);
    if (labels.empty())
        return 0;

    // Queue every pair from the upper triangle of the distance matrix.  The
    // compare_dist ordering puts the closest pair at the top.
    std::priority_queue<sample_pair, std::vector<sample_pair>, compare_dist> que;
    for (long r = 0; r < dists.nr(); ++r)
    {
        for (long c = r+1; c < dists.nc(); ++c)
            que.push(sample_pair(r,c,dists(r,c)));
    }

    // Each pass of this loop performs at most one merge.
    for (unsigned long iter = min_num_clusters; iter < sets.size(); ++iter)
    {
        double best_dist;
        unsigned long a;
        unsigned long b;

        // Pop candidates until one is found that joins two different sets and
        // whose recorded distance is not smaller than the current distance
        // between those sets.
        for (;;)
        {
            best_dist = que.top().distance();
            a = sets.find_set(que.top().index1());
            b = sets.find_set(que.top().index2());
            que.pop();

            if (a != b)
            {
                if (!(best_dist < dists(a,b)))
                    break;

                // The recorded distance is stale because earlier merges changed
                // dists.  Requeue the pair with its current distance so it gets
                // reconsidered later.
                que.push(sample_pair(a, b, dists(a, b)));
            }
            // When a == b the two endpoints were already merged, so the entry
            // is simply discarded.
        }

        // Stop merging once even the closest pair is too far apart.
        if (best_dist > max_dist)
            break;

        const unsigned long news = sets.merge_sets(a,b);
        const unsigned long olds = (news == a) ? b : a;
        merge_sets(dists, news, olds);
    }

    // Convert set representatives into contiguous labels, assigned in order of
    // first appearance.
    std::map<unsigned long, unsigned long> relabel;
    for (unsigned long r = 0; r < labels.size(); ++r)
    {
        const unsigned long root = sets.find_set(r);
        const auto found = relabel.find(root);
        if (found != relabel.end())
        {
            labels[r] = found->second;
        }
        else
        {
            const unsigned long next = relabel.size();
            relabel[root] = next;
            labels[r] = next;
        }
    }
    return relabel.size();
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
struct snl_range
{
    /*!
        A range [lower, upper] on the number line.  A default constructed range
        is [0, 0].
    !*/

    snl_range() = default;

    // Deliberately non-explicit: a plain double converts to the zero-width
    // range [val, val].
    snl_range(double val) : lower(val), upper(val) {}

    // Builds [l, u]; the bounds are expected to satisfy l <= u.
    snl_range(double l, double u) : lower(l), upper(u)
    {
        DLIB_ASSERT(lower <= upper);
    }

    double lower = 0;
    double upper = 0;

    // Length of the range.
    double width() const
    {
        return upper-lower;
    }

    // Ranges are ordered by their lower bound only.
    bool operator<(const snl_range& item) const
    {
        return lower < item.lower;
    }
};
// Returns the smallest range that contains both a and b.
inline snl_range merge(const snl_range& a, const snl_range& b)
{
    const double lo = std::min(a.lower, b.lower);
    const double hi = std::max(a.upper, b.upper);
    return snl_range(lo, hi);
}
// Returns the size of the gap between a and b, computed as the larger of the two
// lower bounds minus the smaller of the two upper bounds.  The result is
// negative when the ranges overlap.
inline double distance (const snl_range& a, const snl_range& b)
{
    const double rightmost_lower = std::max(a.lower,b.lower);
    const double leftmost_upper = std::min(a.upper,b.upper);
    return rightmost_lower - leftmost_upper;
}
// Prints a range in the form "[lower,upper]".
inline std::ostream& operator<< (std::ostream& out, const snl_range& item )
{
    return out << "[" << item.lower << "," << item.upper << "]";
}
// ----------------------------------------------------------------------------------------
inline std::vector<snl_range> segment_number_line (
const std::vector<double>& x,
const double max_range_width
)
{
DLIB_CASSERT(max_range_width >= 0);
// create initial ranges, one for each value in x. So initially, all the ranges have
// width of 0.
std::vector<snl_range> ranges;
for (auto v : x)
ranges.push_back(v);
std::sort(ranges.begin(), ranges.end());
std::vector<snl_range> greedy_final_ranges;
if (ranges.size() == 0)
return greedy_final_ranges;
// We will try two different clustering strategies. One that does a simple greedy left
// to right sweep and another that does a bottom up agglomerative clustering. This
// first loop runs the greedy left to right sweep. Then at the end of this routine we
// will return the results that produced the tightest clustering.
greedy_final_ranges.push_back(ranges[0]);
for (size_t i = 1; i < ranges.size(); ++i)
{
auto m = merge(greedy_final_ranges.back(), ranges[i]);
if (m.width() <= max_range_width)
greedy_final_ranges.back() = m;
else
greedy_final_ranges.push_back(ranges[i]);
}
// Here we do the bottom up clustering. So compute the edges connecting our ranges.
// We will simply say there are edges between ranges if and only if they are
// immediately adjacent on the number line.
std::vector<sample_pair> edges;
for (size_t i = 1; i < ranges.size(); ++i)
edges.push_back(sample_pair(i-1,i, distance(ranges[i-1],ranges[i])));
std::sort(edges.begin(), edges.end(), order_by_distance<sample_pair>);
disjoint_subsets sets;
sets.set_size(ranges.size());
// Now start merging nodes.
for (auto edge : edges)
{
// find the next best thing to merge.
unsigned long a = sets.find_set(edge.index1());
unsigned long b = sets.find_set(edge.index2());
// merge it if it doesn't result in an interval that's too big.
auto m = merge(ranges[a], ranges[b]);
if (m.width() <= max_range_width)
{
unsigned long news = sets.merge_sets(a,b);
ranges[news] = m;
}
}
// Now create a list of the final ranges. We will do this by keeping track of which
// range we already added to final_ranges.
std::vector<snl_range> final_ranges;
std::vector<bool> already_output(ranges.size(), false);
for (unsigned long i = 0; i < sets.size(); ++i)
{
auto s = sets.find_set(i);
if (!already_output[s])
{
final_ranges.push_back(ranges[s]);
already_output[s] = true;
}
}
// only use the greedy clusters if they found a clustering with fewer clusters.
// Otherwise, the bottom up clustering probably produced a more sensible clustering.
if (final_ranges.size() <= greedy_final_ranges.size())
return final_ranges;
else
return greedy_final_ranges;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BOTTOM_uP_CLUSTER_Hh_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_
#ifdef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_
#include "../matrix.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename EXP
>
unsigned long bottom_up_cluster (
const matrix_exp<EXP>& dists,
std::vector<unsigned long>& labels,
unsigned long min_num_clusters,
double max_dist = std::numeric_limits<double>::infinity()
);
/*!
requires
- dists.nr() == dists.nc()
- min_num_clusters > 0
- dists == trans(dists)
(i.e. dists should be symmetric)
ensures
- Runs a bottom up agglomerative clustering algorithm.
- Interprets dists as a matrix that gives the distances between dists.nr()
items. In particular, we take dists(i,j) to be the distance between the ith
and jth element of some set. This function clusters the elements of this set
into at least min_num_clusters (or dists.nr() if there aren't enough
elements). Additionally, within each cluster, the maximum pairwise distance
between any two cluster elements is <= max_dist.
- returns the number of clusters found.
- #labels.size() == dists.nr()
- for all valid i:
- #labels[i] == the cluster ID of the node with index i (i.e. the node
corresponding to the distances dists(i,*)).
- 0 <= #labels[i] < the number of clusters found
(i.e. cluster IDs are assigned contiguously and start at 0)
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
struct snl_range
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents an interval on the real number line. It is used
to store the outputs of the segment_number_line() routine defined below.
!*/
snl_range(
);
/*!
ensures
- #lower == 0
- #upper == 0
!*/
snl_range(
double val
);
/*!
ensures
- #lower == val
- #upper == val
!*/
snl_range(
double l,
double u
);
/*!
requires
- l <= u
ensures
- #lower == l
- #upper == u
!*/
double lower;
double upper;
double width(
) const { return upper-lower; }
/*!
ensures
- returns the width of this interval on the number line.
!*/
bool operator<(const snl_range& item) const { return lower < item.lower; }
/*!
ensures
- provides a total ordering of snl_range objects assuming they are
non-overlapping.
!*/
};
std::ostream& operator<< (std::ostream& out, const snl_range& item );
/*!
ensures
- prints item to out in the form [lower,upper].
!*/
// ----------------------------------------------------------------------------------------
std::vector<snl_range> segment_number_line (
const std::vector<double>& x,
const double max_range_width
);
/*!
requires
- max_range_width >= 0
ensures
- Finds a clustering of the values in x and returns the ranges that define the
clustering. This routine uses a combination of bottom up clustering and a
simple greedy scan to try and find the most compact set of ranges that
contain all the values in x.
- This routine has approximately linear runtime.
- Every value in x will be contained inside one of the returned snl_range
objects.
- All returned snl_range objects will have a width() <= max_range_width and
will also be non-overlapping.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CHINESE_WHISPErS_Hh_
#define DLIB_CHINESE_WHISPErS_Hh_
#include "chinese_whispers_abstract.h"
#include <vector>
#include "../rand.h"
#include "../graph_utils/edge_list_graphs.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
inline unsigned long chinese_whispers (
const std::vector<ordered_sample_pair>& edges,
std::vector<unsigned long>& labels,
const unsigned long num_iterations,
dlib::rand& rnd
)
{
// make sure requires clause is not broken
DLIB_ASSERT(is_ordered_by_index(edges),
"\t unsigned long chinese_whispers()"
<< "\n\t Invalid inputs were given to this function"
);
labels.clear();
if (edges.size() == 0)
return 0;
std::vector<std::pair<unsigned long, unsigned long> > neighbors;
find_neighbor_ranges(edges, neighbors);
// Initialize the labels, each node gets a different label.
labels.resize(neighbors.size());
for (unsigned long i = 0; i < labels.size(); ++i)
labels[i] = i;
for (unsigned long iter = 0; iter < neighbors.size()*num_iterations; ++iter)
{
// Pick a random node.
const unsigned long idx = rnd.get_random_64bit_number()%neighbors.size();
// Count how many times each label happens amongst our neighbors.
std::map<unsigned long, double> labels_to_counts;
const unsigned long end = neighbors[idx].second;
for (unsigned long i = neighbors[idx].first; i != end; ++i)
{
labels_to_counts[labels[edges[i].index2()]] += edges[i].distance();
}
// find the most common label
std::map<unsigned long, double>::iterator i;
double best_score = -std::numeric_limits<double>::infinity();
unsigned long best_label = labels[idx];
for (i = labels_to_counts.begin(); i != labels_to_counts.end(); ++i)
{
if (i->second > best_score)
{
best_score = i->second;
best_label = i->first;
}
}
labels[idx] = best_label;
}
// Remap the labels into a contiguous range. First we find the
// mapping.
std::map<unsigned long,unsigned long> label_remap;
for (unsigned long i = 0; i < labels.size(); ++i)
{
const unsigned long next_id = label_remap.size();
if (label_remap.count(labels[i]) == 0)
label_remap[labels[i]] = next_id;
}
// now apply the mapping to all the labels.
for (unsigned long i = 0; i < labels.size(); ++i)
{
labels[i] = label_remap[labels[i]];
}
return label_remap.size();
}
// ----------------------------------------------------------------------------------------
inline unsigned long chinese_whispers (
const std::vector<sample_pair>& edges,
std::vector<unsigned long>& labels,
const unsigned long num_iterations,
dlib::rand& rnd
)
{
std::vector<ordered_sample_pair> oedges;
convert_unordered_to_ordered(edges, oedges);
std::sort(oedges.begin(), oedges.end(), &order_by_index<ordered_sample_pair>);
return chinese_whispers(oedges, labels, num_iterations, rnd);
}
// ----------------------------------------------------------------------------------------
inline unsigned long chinese_whispers (
const std::vector<sample_pair>& edges,
std::vector<unsigned long>& labels,
const unsigned long num_iterations = 100
)
{
dlib::rand rnd;
return chinese_whispers(edges, labels, num_iterations, rnd);
}
// ----------------------------------------------------------------------------------------
inline unsigned long chinese_whispers (
const std::vector<ordered_sample_pair>& edges,
std::vector<unsigned long>& labels,
const unsigned long num_iterations = 100
)
{
dlib::rand rnd;
return chinese_whispers(edges, labels, num_iterations, rnd);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CHINESE_WHISPErS_Hh_
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_
#ifdef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_
#include <vector>
#include "../rand.h"
#include "../graph_utils/ordered_sample_pair_abstract.h"
#include "../graph_utils/sample_pair_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
unsigned long chinese_whispers (
const std::vector<ordered_sample_pair>& edges,
std::vector<unsigned long>& labels,
const unsigned long num_iterations,
dlib::rand& rnd
);
/*!
requires
- is_ordered_by_index(edges) == true
ensures
- This function implements the graph clustering algorithm described in the
paper: Chinese Whispers - an Efficient Graph Clustering Algorithm and its
Application to Natural Language Processing Problems by Chris Biemann.
- Interprets edges as a directed graph. That is, it contains the edges on the
said graph and the ordered_sample_pair::distance() values define the edge
weights (larger values indicating a stronger edge connection between the
nodes). If an edge has a distance() value of infinity then it is considered
a "must link" edge.
- returns the number of clusters found.
- #labels.size() == max_index_plus_one(edges)
- for all valid i:
- #labels[i] == the cluster ID of the node with index i in the graph.
- 0 <= #labels[i] < the number of clusters found
(i.e. cluster IDs are assigned contiguously and start at 0)
- Duplicate edges are interpreted as if there had been just one edge with a
distance value equal to the sum of all the duplicate edges' distance values.
- The algorithm performs exactly num_iterations passes over the graph before
terminating.
!*/
// ----------------------------------------------------------------------------------------
unsigned long chinese_whispers (
const std::vector<sample_pair>& edges,
std::vector<unsigned long>& labels,
const unsigned long num_iterations,
dlib::rand& rnd
);
/*!
ensures
- This function is identical to the above chinese_whispers() routine except
that it operates on a vector of sample_pair objects instead of
ordered_sample_pairs. Therefore, this is simply a convenience routine. In
particular, it is implemented by transforming the given edges into
ordered_sample_pairs and then calling the chinese_whispers() routine defined
above.
!*/
// ----------------------------------------------------------------------------------------
unsigned long chinese_whispers (
const std::vector<ordered_sample_pair>& edges,
std::vector<unsigned long>& labels,
const unsigned long num_iterations = 100
);
/*!
requires
- is_ordered_by_index(edges) == true
ensures
- performs: return chinese_whispers(edges, labels, num_iterations, rnd)
where rnd is a default initialized dlib::rand object.
!*/
// ----------------------------------------------------------------------------------------
unsigned long chinese_whispers (
const std::vector<sample_pair>& edges,
std::vector<unsigned long>& labels,
const unsigned long num_iterations = 100
);
/*!
ensures
- performs: return chinese_whispers(edges, labels, num_iterations, rnd)
where rnd is a default initialized dlib::rand object.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MODULARITY_ClUSTERING__H__
#define DLIB_MODULARITY_ClUSTERING__H__
#include "modularity_clustering_abstract.h"
#include "../sparse_vector.h"
#include "../graph_utils/edge_list_graphs.h"
#include "../matrix.h"
#include "../rand.h"
namespace dlib
{
// -----------------------------------------------------------------------------------------
namespace impl
{
// Bisects a graph: runs a power iteration on the modularity matrix B (never
// formed explicitly), thresholds the resulting vector into +1/-1 labels, and
// then greedily flips individual labels while that raises the score.  The full
// contract is in the spec comment below.
inline double newman_cluster_split (
dlib::rand& rnd,
const std::vector<ordered_sample_pair>& edges,
const matrix<double,0,1>& node_degrees, // k from the Newman paper
const matrix<double,0,1>& Bdiag, // diag(B) from the Newman paper
const double& edge_sum, // m from the Newman paper
matrix<double,0,1>& labels,
const double eps,
const unsigned long max_iterations
)
/*!
requires
- node_degrees.size() == max_index_plus_one(edges)
- Bdiag.size() == max_index_plus_one(edges)
- edges must be sorted according to order_by_index()
ensures
- This routine splits a graph into two subgraphs using the Newman
clustering method.
- returns the modularity obtained when the graph is split according
to the contents of #labels.
- #labels.size() == node_degrees.size()
- for all valid i: #labels(i) == -1 or +1
- if (this function returns 0) then
- all the labels are equal, i.e. the graph is not split.
!*/
{
// Scale epsilon so that it is relative to the expected value of an element of a
// unit vector of length node_degrees.size().
const double power_iter_eps = eps * std::sqrt(1.0/node_degrees.size());
// Make a random unit vector and put in labels.
labels.set_size(node_degrees.size());
for (long i = 0; i < labels.size(); ++i)
labels(i) = rnd.get_random_gaussian();
labels /= length(labels);
matrix<double,0,1> Bv, Bv_unit;
// Do the power iteration for a while.
// eig starts out negative so the loop body runs at least once.  offset is the
// shift subtracted inside the iteration (Bv -= offset*labels); it is 0 on the
// first pass.
double eig = -1;
double offset = 0;
while (eig < 0)
{
// any number larger than power_iter_eps
double iteration_change = power_iter_eps*2+1;
for (unsigned long i = 0; i < max_iterations && iteration_change > power_iter_eps; ++i)
{
// Bv = B*labels without building B: the sparse product over the edge list
// (helper defined elsewhere; presumably the adjacency-weight product) followed
// by subtraction of the rank-one term (k'*labels)/(2m) * k.
sparse_matrix_vector_multiply(edges, labels, Bv);
Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees;
if (offset != 0)
{
Bv -= offset*labels;
}
const double len = length(Bv);
if (len != 0)
{
// Normalize, measure the largest per-element change against the previous
// vector (this drives the convergence test), and make the new unit vector the
// current iterate.
Bv_unit = Bv/len;
iteration_change = max(abs(labels-Bv_unit));
labels.swap(Bv_unit);
}
else
{
// Had a bad time, pick another random vector and try it with the
// power iteration.
// The fresh vector is not normalized here; normalization happens on the next
// pass through the branch above.
for (long i = 0; i < labels.size(); ++i)
labels(i) = rnd.get_random_gaussian();
}
}
// NOTE(review): when the swap above ran on the last pass, labels == Bv/len, so
// this dot product equals length(Bv) and cannot be negative (and it is 0 when
// Bv was all zeros).  Confirm whether the retry for a negative eigenvalue is
// actually reachable.
eig = dot(Bv,labels);
// we will repeat this loop if the largest eigenvalue is negative
offset = eig;
}
// Threshold the vector by sign into the two sides of the split.  An element
// that is exactly 0 goes to the -1 side.
for (long i = 0; i < labels.size(); ++i)
{
if (labels(i) > 0)
labels(i) = 1;
else
labels(i) = -1;
}
// compute B*labels, store result in Bv.
sparse_matrix_vector_multiply(edges, labels, Bv);
Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees;
// Do some label refinement. In this step we swap labels if it
// improves the modularity score.
// Sweeps over all nodes are repeated until a full sweep flips nothing.
bool flipped_label = true;
while(flipped_label)
{
flipped_label = false;
// idx walks the edge list in step with i.  This depends on edges being sorted
// by index (see the requires clause) so that all edges with index1() == i are
// contiguous and appear in increasing order of i.
unsigned long idx = 0;
for (long i = 0; i < labels.size(); ++i)
{
// val is how much labels(i) would change by if flipped (l -> -l).  With
// Bv == B*labels, increase is the resulting change in labels'*B*labels:
// val*val*B(i,i) + 2*val*Bv(i), where val*val == 4.
const double val = -2*labels(i);
const double increase = 4*Bdiag(i) + 2*val*Bv(i);
// if there is an increase in modularity for swapping this label
if (increase > 0)
{
labels(i) *= -1;
// Keep Bv equal to B*labels after the flip: first apply the edge-weight
// contribution of node i's edges ...
while (idx < edges.size() && edges[idx].index1() == (unsigned long)i)
{
const long j = edges[idx].index2();
Bv(j) += val*edges[idx].distance();
++idx;
}
// ... then the matching rank-one (degree) correction.
Bv -= (val*node_degrees(i)/(2*edge_sum))*node_degrees;
flipped_label = true;
}
else
{
// No flip: just step past node i's edges so idx stays aligned with i.
while (idx < edges.size() && edges[idx].index1() == (unsigned long)i)
{
++idx;
}
}
}
}
// Score of the final labeling: labels'*B*labels / (4m).
const double modularity = dot(Bv, labels)/(4*edge_sum);
return modularity;
}
// -------------------------------------------------------------------------------------
// Recursive driver for newman_cluster(): bisects the graph with
// newman_cluster_split(), recurses into each half while the split scores above
// modularity_threshold, and writes contiguous cluster IDs into labels.
// node_degrees, Bdiag and edge_sum always refer to the ORIGINAL graph; the
// recursion only selects the entries belonging to the current subgraph and
// passes edge_sum through unchanged.
inline unsigned long newman_cluster_helper (
dlib::rand& rnd,
const std::vector<ordered_sample_pair>& edges,
const matrix<double,0,1>& node_degrees, // k from the Newman paper
const matrix<double,0,1>& Bdiag, // diag(B) from the Newman paper
const double& edge_sum, // m from the Newman paper
std::vector<unsigned long>& labels,
double modularity_threshold,
const double eps,
const unsigned long max_iterations
)
/*!
ensures
- returns the number of clusters the data was split into
!*/
{
// l(i) is +1 or -1 after this call; positive entries form the "left" side of
// the split and the remaining entries the "right" side.
matrix<double,0,1> l;
const double modularity = newman_cluster_split(rnd,edges,node_degrees,Bdiag,edge_sum,l,eps,max_iterations);
// We need to collapse the node index values down to contiguous values. So
// we use the following two vectors to contain the mappings from input index
// values to their corresponding index values in each split.
// Only the entry for the side a node landed on is meaningful; the entry in the
// other map is never read for that node.
std::vector<unsigned long> left_idx_map(node_degrees.size());
std::vector<unsigned long> right_idx_map(node_degrees.size());
// figure out how many nodes went into each side of the split.
unsigned long num_left_split = 0;
unsigned long num_right_split = 0;
for (long i = 0; i < l.size(); ++i)
{
if (l(i) > 0)
{
left_idx_map[i] = num_left_split;
++num_left_split;
}
else
{
right_idx_map[i] = num_right_split;
++num_right_split;
}
}
// do a recursive split if it will improve the modularity.
// A split that leaves one side empty is treated as no split at all.
if (modularity > modularity_threshold && num_left_split > 0 && num_right_split > 0)
{
// split the node_degrees and Bdiag matrices into left and right split parts
matrix<double,0,1> left_node_degrees(num_left_split);
matrix<double,0,1> right_node_degrees(num_right_split);
matrix<double,0,1> left_Bdiag(num_left_split);
matrix<double,0,1> right_Bdiag(num_right_split);
for (long i = 0; i < l.size(); ++i)
{
if (l(i) > 0)
{
left_node_degrees(left_idx_map[i]) = node_degrees(i);
left_Bdiag(left_idx_map[i]) = Bdiag(i);
}
else
{
right_node_degrees(right_idx_map[i]) = node_degrees(i);
right_Bdiag(right_idx_map[i]) = Bdiag(i);
}
}
// put the edges from one side of the split into split_edges
// Only edges with both end points on the left side are kept (re-indexed into
// the compact left numbering); edges crossing the split are dropped.
// modularity_threshold is reused as an accumulator for the threshold handed to
// the recursive call: (sum of kept edge weights - (sum of left degrees)^2/(2m))
// / (4m).  Presumably this is the score of leaving the left side unsplit, which
// a further split has to beat -- confirm against the Newman paper.
std::vector<ordered_sample_pair> split_edges;
modularity_threshold = 0;
for (unsigned long k = 0; k < edges.size(); ++k)
{
const unsigned long i = edges[k].index1();
const unsigned long j = edges[k].index2();
const double d = edges[k].distance();
if (l(i) > 0 && l(j) > 0)
{
split_edges.push_back(ordered_sample_pair(left_idx_map[i], left_idx_map[j], d));
modularity_threshold += d;
}
}
modularity_threshold -= sum(left_node_degrees*sum(left_node_degrees))/(2*edge_sum);
modularity_threshold /= 4*edge_sum;
unsigned long num_left_clusters;
std::vector<unsigned long> left_labels;
num_left_clusters = newman_cluster_helper(rnd,split_edges,left_node_degrees,left_Bdiag,
edge_sum,left_labels,modularity_threshold,
eps, max_iterations);
// now load the other side into split_edges and cluster it as well
// Same procedure as for the left side, using the right side's nodes and edges.
split_edges.clear();
modularity_threshold = 0;
for (unsigned long k = 0; k < edges.size(); ++k)
{
const unsigned long i = edges[k].index1();
const unsigned long j = edges[k].index2();
const double d = edges[k].distance();
if (l(i) < 0 && l(j) < 0)
{
split_edges.push_back(ordered_sample_pair(right_idx_map[i], right_idx_map[j], d));
modularity_threshold += d;
}
}
modularity_threshold -= sum(right_node_degrees*sum(right_node_degrees))/(2*edge_sum);
modularity_threshold /= 4*edge_sum;
unsigned long num_right_clusters;
std::vector<unsigned long> right_labels;
num_right_clusters = newman_cluster_helper(rnd,split_edges,right_node_degrees,right_Bdiag,
edge_sum,right_labels,modularity_threshold,
eps, max_iterations);
// Now merge the labels from the two splits.
// Right side IDs are shifted by num_left_clusters so the combined IDs remain
// contiguous and do not collide with the left side's IDs.
labels.resize(node_degrees.size());
for (unsigned long i = 0; i < labels.size(); ++i)
{
// if this node was in the left split
if (l(i) > 0)
{
labels[i] = left_labels[left_idx_map[i]];
}
else // if this node was in the right split
{
labels[i] = right_labels[right_idx_map[i]] + num_left_clusters;
}
}
return num_left_clusters + num_right_clusters;
}
else
{
// Splitting did not beat the threshold (or one side came out empty): the whole
// subgraph is a single cluster with ID 0.
labels.assign(node_degrees.size(),0);
return 1;
}
}
}
// ----------------------------------------------------------------------------------------
inline unsigned long newman_cluster (
const std::vector<ordered_sample_pair>& edges,
std::vector<unsigned long>& labels,
const double eps = 1e-4,
const unsigned long max_iterations = 2000
)
{
// make sure requires clause is not broken
DLIB_ASSERT(is_ordered_by_index(edges),
"\t unsigned long newman_cluster()"
<< "\n\t Invalid inputs were given to this function"
);
labels.clear();
if (edges.size() == 0)
return 0;
const unsigned long num_nodes = max_index_plus_one(edges);
// compute the node_degrees vector, edge_sum value, and diag(B).
matrix<double,0,1> node_degrees(num_nodes);
matrix<double,0,1> Bdiag(num_nodes);
Bdiag = 0;
double edge_sum = 0;
node_degrees = 0;
for (unsigned long i = 0; i < edges.size(); ++i)
{
node_degrees(edges[i].index1()) += edges[i].distance();
edge_sum += edges[i].distance();
if (edges[i].index1() == edges[i].index2())
Bdiag(edges[i].index1()) += edges[i].distance();
}
edge_sum /= 2;
Bdiag -= squared(node_degrees)/(2*edge_sum);
dlib::rand rnd;
return impl::newman_cluster_helper(rnd,edges,node_degrees,Bdiag,edge_sum,labels,0,eps,max_iterations);
}
// ----------------------------------------------------------------------------------------
inline unsigned long newman_cluster (
const std::vector<sample_pair>& edges,
std::vector<unsigned long>& labels,
const double eps = 1e-4,
const unsigned long max_iterations = 2000
)
{
std::vector<ordered_sample_pair> oedges;
convert_unordered_to_ordered(edges, oedges);
std::sort(oedges.begin(), oedges.end(), &order_by_index<ordered_sample_pair>);
return newman_cluster(oedges, labels, eps, max_iterations);
}
// ----------------------------------------------------------------------------------------
namespace impl
{
inline std::vector<unsigned long> remap_labels (
    const std::vector<unsigned long>& labels,
    unsigned long& num_labels
)
/*!
    ensures
        - Renumbers the values in labels so that they occupy the most compact
          range [0, max] possible, i.e. no integer in the mapped range is unused.
          New IDs are handed out in order of first appearance.
        - #num_labels == the number of distinct values in labels.
        - returns a vector V such that:
            - V.size() == labels.size()
            - max(mat(V))+1 == num_labels.
            - for all valid i,j:
                - if (labels[i] == labels[j]) then
                    - V[i] == V[j]
                - else
                    - V[i] != V[j]
!*/
{
    std::map<unsigned long, unsigned long> compact_id;
    std::vector<unsigned long> result;
    result.reserve(labels.size());
    for (unsigned long i = 0; i < labels.size(); ++i)
    {
        // insert() leaves the map untouched when the label has been seen before
        // and, either way, hands back the entry holding its compact ID.
        const unsigned long next = compact_id.size();
        result.push_back(compact_id.insert(std::make_pair(labels[i], next)).first->second);
    }
    num_labels = compact_id.size();
    return result;
}
}
// ----------------------------------------------------------------------------------------
inline double modularity (
const std::vector<sample_pair>& edges,
const std::vector<unsigned long>& labels
)
{
const unsigned long num_nodes = max_index_plus_one(edges);
// make sure requires clause is not broken
DLIB_ASSERT(labels.size() == num_nodes,
"\t double modularity()"
<< "\n\t Invalid inputs were given to this function"
);
unsigned long num_labels;
const std::vector<unsigned long>& labels_ = dlib::impl::remap_labels(labels,num_labels);
std::vector<double> cluster_sums(num_labels,0);
std::vector<double> k(num_nodes,0);
double Q = 0;
double m = 0;
for (unsigned long i = 0; i < edges.size(); ++i)
{
const unsigned long n1 = edges[i].index1();
const unsigned long n2 = edges[i].index2();
k[n1] += edges[i].distance();
if (n1 != n2)
k[n2] += edges[i].distance();
if (n1 != n2)
m += edges[i].distance();
else
m += edges[i].distance()/2;
if (labels_[n1] == labels_[n2])
{
if (n1 != n2)
Q += 2*edges[i].distance();
else
Q += edges[i].distance();
}
}
if (m == 0)
return 0;
for (unsigned long i = 0; i < labels_.size(); ++i)
{
cluster_sums[labels_[i]] += k[i];
}
for (unsigned long i = 0; i < labels_.size(); ++i)
{
Q -= k[i]*cluster_sums[labels_[i]]/(2*m);
}
return 1.0/(2*m)*Q;
}
// ----------------------------------------------------------------------------------------
inline double modularity (
const std::vector<ordered_sample_pair>& edges,
const std::vector<unsigned long>& labels
)
{
const unsigned long num_nodes = max_index_plus_one(edges);
// make sure requires clause is not broken
DLIB_ASSERT(labels.size() == num_nodes,
"\t double modularity()"
<< "\n\t Invalid inputs were given to this function"
);
unsigned long num_labels;
const std::vector<unsigned long>& labels_ = dlib::impl::remap_labels(labels,num_labels);
std::vector<double> cluster_sums(num_labels,0);
std::vector<double> k(num_nodes,0);
double Q = 0;
double m = 0;
for (unsigned long i = 0; i < edges.size(); ++i)
{
const unsigned long n1 = edges[i].index1();
const unsigned long n2 = edges[i].index2();
k[n1] += edges[i].distance();
m += edges[i].distance();
if (labels_[n1] == labels_[n2])
{
Q += edges[i].distance();
}
}
if (m == 0)
return 0;
for (unsigned long i = 0; i < labels_.size(); ++i)
{
cluster_sums[labels_[i]] += k[i];
}
for (unsigned long i = 0; i < labels_.size(); ++i)
{
Q -= k[i]*cluster_sums[labels_[i]]/m;
}
return 1.0/m*Q;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_MODULARITY_ClUSTERING__H__
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_
#ifdef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_
#include <vector>
#include "../graph_utils/ordered_sample_pair_abstract.h"
#include "../graph_utils/sample_pair_abstract.h"
namespace dlib
{
// -----------------------------------------------------------------------------------------
double modularity (
const std::vector<sample_pair>& edges,
const std::vector<unsigned long>& labels
);
/*!
requires
- labels.size() == max_index_plus_one(edges)
- for all valid i:
- 0 <= edges[i].distance() < std::numeric_limits<double>::infinity()
ensures
- Interprets edges as an undirected graph. That is, it contains the edges on
the said graph and the sample_pair::distance() values define the edge weights
(larger values indicating a stronger edge connection between the nodes).
- This function returns the modularity value obtained when the given input
graph is broken into subgraphs according to the contents of labels. In
particular, we say that two nodes with indices i and j are in the same
subgraph or community if and only if labels[i] == labels[j].
- Duplicate edges are interpreted as if there had been just one edge with a
distance value equal to the sum of all the duplicate edges' distance values.
- See the paper Modularity and community structure in networks by M. E. J. Newman
for a detailed definition.
!*/
// ----------------------------------------------------------------------------------------
double modularity (
const std::vector<ordered_sample_pair>& edges,
const std::vector<unsigned long>& labels
);
/*!
requires
- labels.size() == max_index_plus_one(edges)
- for all valid i:
- 0 <= edges[i].distance() < std::numeric_limits<double>::infinity()
ensures
- Interprets edges as a directed graph. That is, it contains the edges on the
said graph and the ordered_sample_pair::distance() values define the edge
weights (larger values indicating a stronger edge connection between the
nodes). Note that, generally, modularity is only really defined for
undirected graphs. Therefore, the "directed graph" given to this function
should have symmetric edges between all nodes. The reason this function is
provided at all is because sometimes a vector of ordered_sample_pair objects
is a useful representation of an undirected graph.
- This function returns the modularity value obtained when the given input
graph is broken into subgraphs according to the contents of labels. In
particular, we say that two nodes with indices i and j are in the same
subgraph or community if and only if labels[i] == labels[j].
- Duplicate edges are interpreted as if there had been just one edge with a
distance value equal to the sum of all the duplicate edges' distance values.
- See the paper Modularity and community structure in networks by M. E. J. Newman
for a detailed definition.
!*/
// ----------------------------------------------------------------------------------------
unsigned long newman_cluster (
const std::vector<ordered_sample_pair>& edges,
std::vector<unsigned long>& labels,
const double eps = 1e-4,
const unsigned long max_iterations = 2000
);
/*!
requires
- is_ordered_by_index(edges) == true
- for all valid i:
- 0 <= edges[i].distance() < std::numeric_limits<double>::infinity()
ensures
- This function performs the clustering algorithm described in the paper
Modularity and community structure in networks by M. E. J. Newman.
- This function interprets edges as a graph and attempts to find the labeling
that maximizes modularity(edges, #labels).
- returns the number of clusters found.
- #labels.size() == max_index_plus_one(edges)
- for all valid i:
- #labels[i] == the cluster ID of the node with index i in the graph.
- 0 <= #labels[i] < the number of clusters found
(i.e. cluster IDs are assigned contiguously and start at 0)
- The main computation of the algorithm is involved in finding an eigenvector
of a certain matrix. To do this, we use the power iteration. In particular,
each time we try to find an eigenvector we will let the power iteration loop
at most max_iterations times or until it reaches an accuracy of eps.
Whichever comes first.
!*/
// ----------------------------------------------------------------------------------------
unsigned long newman_cluster (
const std::vector<sample_pair>& edges,
std::vector<unsigned long>& labels,
const double eps = 1e-4,
const unsigned long max_iterations = 2000
);
/*!
requires
- for all valid i:
- 0 <= edges[i].distance() < std::numeric_limits<double>::infinity()
ensures
- This function is identical to the above newman_cluster() routine except that
it operates on a vector of sample_pair objects instead of
ordered_sample_pairs. Therefore, this is simply a convenience routine. In
particular, it is implemented by transforming the given edges into
ordered_sample_pairs and then calling the newman_cluster() routine defined
above.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment