"tools/vscode:/vscode.git/clone" did not exist on "309902b81a98bf4461baef609f818243736d5cc4"
Commit 66d5a906 authored by Davis King's avatar Davis King
Browse files

Fully setup the functional python interface to the sequence segmenter tool.

Need to add documentation next.
parent bb0f764c
......@@ -12,7 +12,7 @@ using namespace dlib;
using namespace std;
using namespace boost::python;
typedef matrix<double,0,1> sample_type;
typedef matrix<double,0,1> dense_vect;
typedef std::vector<std::pair<unsigned long,double> > sparse_vect;
typedef std::vector<std::pair<unsigned long, unsigned long> > ranges;
......@@ -33,7 +33,7 @@ public:
unsigned long _window_size;
segmenter_feature_extractor(
) : _num_features(0), _window_size(0) {}
) : _num_features(1), _window_size(1) {}
segmenter_feature_extractor(
unsigned long _num_features_,
......@@ -49,7 +49,7 @@ public:
template <typename feature_setter>
void get_features (
feature_setter& set_feature,
const std::vector<sample_type>& x,
const std::vector<dense_vect>& x,
unsigned long position
) const
{
......@@ -88,17 +88,54 @@ public:
struct segmenter_type
{
segmenter_type() : mode(0)
/*!
WHAT THIS OBJECT REPRESENTS
This is the object that python will use directly to represent a
sequence_segmenter. All it does is contain all the possible template
instantiations of a sequence_segmenter and invoke the right one depending on
the mode variable.
!*/
segmenter_type() : mode(-1)
{ }
ranges segment_sequence (
const std::vector<sample_type>& x
ranges segment_sequence_dense (
const std::vector<dense_vect>& x
) const
{
return ranges();
switch (mode)
{
case 0: return segmenter0(x);
case 1: return segmenter1(x);
case 2: return segmenter2(x);
case 3: return segmenter3(x);
case 4: return segmenter4(x);
case 5: return segmenter5(x);
case 6: return segmenter6(x);
case 7: return segmenter7(x);
default: throw dlib::error("Invalid mode");
}
}
const matrix<double,0,1>& get_weights()
ranges segment_sequence_sparse (
const std::vector<sparse_vect>& x
) const
{
switch (mode)
{
case 8: return segmenter8(x);
case 9: return segmenter9(x);
case 10: return segmenter10(x);
case 11: return segmenter11(x);
case 12: return segmenter12(x);
case 13: return segmenter13(x);
case 14: return segmenter14(x);
case 15: return segmenter15(x);
default: throw dlib::error("Invalid mode");
}
}
const matrix<double,0,1> get_weights()
{
switch(mode)
{
......@@ -110,6 +147,17 @@ struct segmenter_type
case 5: return segmenter5.get_weights();
case 6: return segmenter6.get_weights();
case 7: return segmenter7.get_weights();
case 8: return segmenter8.get_weights();
case 9: return segmenter9.get_weights();
case 10: return segmenter10.get_weights();
case 11: return segmenter11.get_weights();
case 12: return segmenter12.get_weights();
case 13: return segmenter13.get_weights();
case 14: return segmenter14.get_weights();
case 15: return segmenter15.get_weights();
default: throw dlib::error("Invalid mode");
}
}
......@@ -126,6 +174,16 @@ struct segmenter_type
case 5: serialize(item.segmenter5, out); break;
case 6: serialize(item.segmenter6, out); break;
case 7: serialize(item.segmenter7, out); break;
case 8: serialize(item.segmenter8, out); break;
case 9: serialize(item.segmenter9, out); break;
case 10: serialize(item.segmenter10, out); break;
case 11: serialize(item.segmenter11, out); break;
case 12: serialize(item.segmenter12, out); break;
case 13: serialize(item.segmenter13, out); break;
case 14: serialize(item.segmenter14, out); break;
case 15: serialize(item.segmenter15, out); break;
default: throw dlib::error("Invalid mode");
}
}
friend void deserialize (segmenter_type& item, std::istream& in)
......@@ -141,19 +199,29 @@ struct segmenter_type
case 5: deserialize(item.segmenter5, in); break;
case 6: deserialize(item.segmenter6, in); break;
case 7: deserialize(item.segmenter7, in); break;
case 8: deserialize(item.segmenter8, in); break;
case 9: deserialize(item.segmenter9, in); break;
case 10: deserialize(item.segmenter10, in); break;
case 11: deserialize(item.segmenter11, in); break;
case 12: deserialize(item.segmenter12, in); break;
case 13: deserialize(item.segmenter13, in); break;
case 14: deserialize(item.segmenter14, in); break;
case 15: deserialize(item.segmenter15, in); break;
default: throw dlib::error("Invalid mode");
}
}
int mode;
typedef segmenter_feature_extractor<sample_type, true, true, true> fe0;
typedef segmenter_feature_extractor<sample_type, true, true, false> fe1;
typedef segmenter_feature_extractor<sample_type, true, false,true> fe2;
typedef segmenter_feature_extractor<sample_type, true, false,false> fe3;
typedef segmenter_feature_extractor<sample_type, false,true, true> fe4;
typedef segmenter_feature_extractor<sample_type, false,true, false> fe5;
typedef segmenter_feature_extractor<sample_type, false,false,true> fe6;
typedef segmenter_feature_extractor<sample_type, false,false,false> fe7;
typedef segmenter_feature_extractor<dense_vect, false,false,false> fe0;
typedef segmenter_feature_extractor<dense_vect, false,false,true> fe1;
typedef segmenter_feature_extractor<dense_vect, false,true, false> fe2;
typedef segmenter_feature_extractor<dense_vect, false,true, true> fe3;
typedef segmenter_feature_extractor<dense_vect, true, false,false> fe4;
typedef segmenter_feature_extractor<dense_vect, true, false,true> fe5;
typedef segmenter_feature_extractor<dense_vect, true, true, false> fe6;
typedef segmenter_feature_extractor<dense_vect, true, true, true> fe7;
sequence_segmenter<fe0> segmenter0;
sequence_segmenter<fe1> segmenter1;
sequence_segmenter<fe2> segmenter2;
......@@ -163,14 +231,14 @@ struct segmenter_type
sequence_segmenter<fe6> segmenter6;
sequence_segmenter<fe7> segmenter7;
typedef segmenter_feature_extractor<sparse_vect, true, true, true> fe8;
typedef segmenter_feature_extractor<sparse_vect, true, true, false> fe9;
typedef segmenter_feature_extractor<sparse_vect, true, false,true> fe10;
typedef segmenter_feature_extractor<sparse_vect, true, false,false> fe11;
typedef segmenter_feature_extractor<sparse_vect, false,true, true> fe12;
typedef segmenter_feature_extractor<sparse_vect, false,true, false> fe13;
typedef segmenter_feature_extractor<sparse_vect, false,false,true> fe14;
typedef segmenter_feature_extractor<sparse_vect, false,false,false> fe15;
typedef segmenter_feature_extractor<sparse_vect, false,false,false> fe8;
typedef segmenter_feature_extractor<sparse_vect, false,false,true> fe9;
typedef segmenter_feature_extractor<sparse_vect, false,true, false> fe10;
typedef segmenter_feature_extractor<sparse_vect, false,true, true> fe11;
typedef segmenter_feature_extractor<sparse_vect, true, false,false> fe12;
typedef segmenter_feature_extractor<sparse_vect, true, false,true> fe13;
typedef segmenter_feature_extractor<sparse_vect, true, true, false> fe14;
typedef segmenter_feature_extractor<sparse_vect, true, true, true> fe15;
sequence_segmenter<fe8> segmenter8;
sequence_segmenter<fe9> segmenter9;
sequence_segmenter<fe10> segmenter10;
......@@ -195,6 +263,7 @@ struct segmenter_params
num_threads = 4;
epsilon = 0.1;
max_cache_size = 40;
be_verbose = false;
C = 100;
}
......@@ -209,11 +278,77 @@ struct segmenter_params
double C;
};
// Produces the text used for python's str() of a segmenter_params object.
//
// The result is one comma separated line: the three boolean model options come
// first (each written as one of two fixed labels), followed by win=, threads=,
// eps= and cache= entries, the verbosity label, and finally C=.
string segmenter_params__str__(const segmenter_params& p)
{
    // Pick the label for every boolean option up front so the stream expression
    // below reads in the same order as the generated text.
    const char* const model_label   = p.use_BIO_model ? "BIO," : "BILOU,";
    const char* const feats_label   = p.use_high_order_features ? "highFeats," : "lowFeats,";
    const char* const sign_label    = p.allow_negative_weights ? "signed," : "non-negative,";
    const char* const verbose_label = p.be_verbose ? "verbose," : "non-verbose,";

    ostringstream sout;
    sout << model_label
         << feats_label
         << sign_label
         << "win=" << p.window_size << ","
         << "threads=" << p.num_threads << ","
         << "eps=" << p.epsilon << ","
         << "cache=" << p.max_cache_size << ","
         << verbose_label
         << "C=" << p.C;

    return trim(sout.str());
}
// Produces the text used for python's repr() of a segmenter_params object: the
// output of segmenter_params__str__() wrapped in angle brackets.
string segmenter_params__repr__(const segmenter_params& p)
{
    return "<" + segmenter_params__str__(p) + ">";
}
// Writes every field of a segmenter_params object to out, one serialize() call
// per field.
//
// The order of the calls below defines the layout of the data in the stream, so
// it has to stay in sync with the order used by the matching
// deserialize(segmenter_params&, std::istream&) routine.
void serialize ( const segmenter_params& item, std::ostream& out)
{
serialize(item.use_BIO_model, out);
serialize(item.use_high_order_features, out);
serialize(item.allow_negative_weights, out);
serialize(item.window_size, out);
serialize(item.num_threads, out);
serialize(item.epsilon, out);
serialize(item.max_cache_size, out);
serialize(item.be_verbose, out);
serialize(item.C, out);
}
// Reads every field of a segmenter_params object from in, one deserialize() call
// per field, overwriting the fields of item.
//
// The fields are read in exactly the order that
// serialize(const segmenter_params&, std::ostream&) writes them; the two
// routines must be changed together.
void deserialize (segmenter_params& item, std::istream& in)
{
deserialize(item.use_BIO_model, in);
deserialize(item.use_high_order_features, in);
deserialize(item.allow_negative_weights, in);
deserialize(item.window_size, in);
deserialize(item.num_threads, in);
deserialize(item.epsilon, in);
deserialize(item.max_cache_size, in);
deserialize(item.be_verbose, in);
deserialize(item.C, in);
}
// ----------------------------------------------------------------------------------------
template <typename T>
void configure_trainer (
const std::vector<std::vector<sample_type> >& samples,
const std::vector<std::vector<dense_vect> >& samples,
structural_sequence_segmentation_trainer<T>& trainer,
const segmenter_params& params
)
......@@ -233,8 +368,35 @@ void configure_trainer (
// ----------------------------------------------------------------------------------------
template <typename T>
void configure_trainer (
const std::vector<std::vector<sparse_vect> >& samples,
structural_sequence_segmentation_trainer<T>& trainer,
const segmenter_params& params
)
{
pyassert(samples.size() != 0, "Invalid arguments. You must give some training sequences.");
pyassert(samples[0].size() != 0, "Invalid arguments. You can't have zero length training sequences.");
unsigned long dims = 0;
for (unsigned long i = 0; i < samples.size(); ++i)
{
dims = std::max(dims, max_index_plus_one(samples[i]));
}
trainer = structural_sequence_segmentation_trainer<T>(T(dims, params.window_size));
trainer.set_num_threads(params.num_threads);
trainer.set_epsilon(params.epsilon);
trainer.set_max_cache_size(params.max_cache_size);
trainer.set_c(params.C);
if (params.be_verbose)
trainer.be_verbose();
}
// ----------------------------------------------------------------------------------------
segmenter_type train_dense (
const std::vector<std::vector<sample_type> >& samples,
const std::vector<std::vector<dense_vect> >& samples,
const std::vector<ranges>& segments,
segmenter_params params
)
......@@ -255,6 +417,7 @@ segmenter_type train_dense (
else
mode = mode*2;
segmenter_type res;
res.mode = mode;
switch(mode)
......@@ -291,6 +454,76 @@ segmenter_type train_dense (
configure_trainer(samples, trainer, params);
res.segmenter7 = trainer.train(samples, segments);
} break;
default: throw dlib::error("Invalid mode");
}
return res;
}
// ----------------------------------------------------------------------------------------
segmenter_type train_sparse (
const std::vector<std::vector<sparse_vect> >& samples,
const std::vector<ranges>& segments,
segmenter_params params
)
{
pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs");
int mode = 0;
if (params.use_BIO_model)
mode = mode*2 + 1;
else
mode = mode*2;
if (params.use_high_order_features)
mode = mode*2 + 1;
else
mode = mode*2;
if (params.allow_negative_weights)
mode = mode*2 + 1;
else
mode = mode*2;
mode += 8;
segmenter_type res;
res.mode = mode;
switch(mode)
{
case 8: { structural_sequence_segmentation_trainer<segmenter_type::fe8> trainer;
configure_trainer(samples, trainer, params);
res.segmenter8 = trainer.train(samples, segments);
} break;
case 9: { structural_sequence_segmentation_trainer<segmenter_type::fe9> trainer;
configure_trainer(samples, trainer, params);
res.segmenter9 = trainer.train(samples, segments);
} break;
case 10: { structural_sequence_segmentation_trainer<segmenter_type::fe10> trainer;
configure_trainer(samples, trainer, params);
res.segmenter10 = trainer.train(samples, segments);
} break;
case 11: { structural_sequence_segmentation_trainer<segmenter_type::fe11> trainer;
configure_trainer(samples, trainer, params);
res.segmenter11 = trainer.train(samples, segments);
} break;
case 12: { structural_sequence_segmentation_trainer<segmenter_type::fe12> trainer;
configure_trainer(samples, trainer, params);
res.segmenter12 = trainer.train(samples, segments);
} break;
case 13: { structural_sequence_segmentation_trainer<segmenter_type::fe13> trainer;
configure_trainer(samples, trainer, params);
res.segmenter13 = trainer.train(samples, segments);
} break;
case 14: { structural_sequence_segmentation_trainer<segmenter_type::fe14> trainer;
configure_trainer(samples, trainer, params);
res.segmenter14 = trainer.train(samples, segments);
} break;
case 15: { structural_sequence_segmentation_trainer<segmenter_type::fe15> trainer;
configure_trainer(samples, trainer, params);
res.segmenter15 = trainer.train(samples, segments);
} break;
default: throw dlib::error("Invalid mode");
}
......@@ -304,21 +537,27 @@ void bind_sequence_segmenter()
class_<segmenter_params>("segmenter_params",
"This class is used to define all the optional parameters to the \n\
train_sequence_segmenter() routine. ")
.add_property("use_BIO_model", &segmenter_params::use_BIO_model)
.add_property("use_high_order_features", &segmenter_params::use_high_order_features)
.add_property("allow_negative_weights", &segmenter_params::allow_negative_weights)
.add_property("window_size", &segmenter_params::window_size)
.add_property("num_threads", &segmenter_params::num_threads)
.add_property("epsilon", &segmenter_params::epsilon)
.add_property("max_cache_size", &segmenter_params::max_cache_size)
.add_property("C", &segmenter_params::C);
.def_readwrite("use_BIO_model", &segmenter_params::use_BIO_model)
.def_readwrite("use_high_order_features", &segmenter_params::use_high_order_features)
.def_readwrite("allow_negative_weights", &segmenter_params::allow_negative_weights)
.def_readwrite("window_size", &segmenter_params::window_size)
.def_readwrite("num_threads", &segmenter_params::num_threads)
.def_readwrite("epsilon", &segmenter_params::epsilon)
.def_readwrite("max_cache_size", &segmenter_params::max_cache_size)
.def_readwrite("C", &segmenter_params::C, "SVM C parameter")
.def("__repr__",&segmenter_params__repr__)
.def("__str__",&segmenter_params__str__)
.def_pickle(serialize_pickle<segmenter_params>());
class_<segmenter_type> ("segmenter_type")
.def("segment_sequence", &segmenter_type::segment_sequence)
.def("segment_sequence", &segmenter_type::segment_sequence_dense)
.def("segment_sequence", &segmenter_type::segment_sequence_sparse)
.def_readonly("weights", &segmenter_type::get_weights)
.def_pickle(serialize_pickle<segmenter_type>());
using boost::python::arg;
def("train_sequence_segmenter", train_dense, (arg("samples"), arg("segments"), arg("params")=segmenter_params()));
def("train_sequence_segmenter", train_sparse, (arg("samples"), arg("segments"), arg("params")=segmenter_params()));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment