// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include #include #include #include #include #include #include "tester.h" namespace { using namespace test; using namespace dlib; using namespace std; logger dlog("test.find_max_factor_graph_viterbi"); // ---------------------------------------------------------------------------------------- dlib::rand rnd; // ---------------------------------------------------------------------------------------- template < unsigned long O, unsigned long NS, unsigned long num_nodes > class map_problem { public: const static unsigned long order = O; const static unsigned long num_states = NS; map_problem() { data = randm(number_of_nodes(),std::pow(num_states,order+1), rnd); } unsigned long number_of_nodes ( ) const { return num_nodes; } template < typename EXP > double factor_value ( unsigned long node_id, const matrix_exp& node_states ) const { if (node_states.size() == 1) return data(node_id, node_states(0)); else if (node_states.size() == 2) return data(node_id, node_states(0) + node_states(1)*num_states); else if (node_states.size() == 3) return data(node_id, (node_states(0) + node_states(1)*num_states)*num_states + node_states(2)); else return data(node_id, ((node_states(0) + node_states(1)*num_states)*num_states + node_states(2))*num_states + node_states(3)); } matrix data; }; // ---------------------------------------------------------------------------------------- template < typename map_problem > void brute_force_find_max_factor_graph_viterbi ( const map_problem& prob, std::vector& map_assignment ) { using namespace dlib::impl; const int order = map_problem::order; const int num_states = map_problem::num_states; map_assignment.resize(prob.number_of_nodes()); double best_score = -std::numeric_limits::infinity(); matrix node_states; node_states.set_size(prob.number_of_nodes()); node_states = 0; do { double score = 0; for (unsigned long i = 0; i < prob.number_of_nodes(); ++i) { score += prob.factor_value(i, (colm(node_states,range(i,i-std::min(order,i))))); } if (score > best_score) { for (unsigned long i = 0; i < map_assignment.size(); ++i) map_assignment[i] = node_states(i); best_score = score; } } while(advance_state(node_states,num_states)); } // ---------------------------------------------------------------------------------------- template < unsigned long order, unsigned long num_states, unsigned long num_nodes > void do_test() { dlog << LINFO << "order: "<< order << " num_states: " << num_states << " num_nodes: " << num_nodes; for (int i = 0; i < 25; ++i) { print_spinner(); map_problem prob; std::vector assign, assign2; brute_force_find_max_factor_graph_viterbi(prob, assign); find_max_factor_graph_viterbi(prob, assign2); DLIB_TEST_MSG(vector_to_matrix(assign) == vector_to_matrix(assign2), trans(vector_to_matrix(assign)) << trans(vector_to_matrix(assign2)) ); } } // ---------------------------------------------------------------------------------------- class test_find_max_factor_graph_viterbi : public tester { public: test_find_max_factor_graph_viterbi ( ) : tester ("test_find_max_factor_graph_viterbi", "Runs tests on the find_max_factor_graph_viterbi routine.") {} void perform_test ( ) { do_test<1,3,0>(); do_test<1,3,1>(); do_test<1,3,2>(); do_test<0,3,2>(); do_test<1,3,8>(); do_test<2,3,7>(); do_test<3,3,8>(); do_test<4,3,8>(); do_test<0,3,8>(); do_test<4,3,1>(); do_test<4,3,0>(); do_test<0,3,0>(); do_test<1,2,8>(); do_test<2,2,7>(); do_test<3,2,8>(); do_test<0,2,8>(); do_test<1,1,8>(); do_test<2,1,8>(); do_test<3,1,8>(); do_test<0,1,8>(); } } a; }