"vscode:/vscode.git/clone" did not exist on "fa9c82d339626d293ea0286d92a776294300e834"
Commit 1c269270 authored by Davis King
Browse files

Added testing and cross validation routines for the python sequence segmenter interface.

parent a4590776
......@@ -355,6 +355,9 @@ void configure_trainer (
{
pyassert(samples.size() != 0, "Invalid arguments. You must give some training sequences.");
pyassert(samples[0].size() != 0, "Invalid arguments. You can't have zero length training sequences.");
pyassert(params.window_size != 0, "Invalid window_size parameter, it must be > 0.");
pyassert(params.epsilon > 0, "Invalid epsilon parameter, it must be > 0.");
pyassert(params.C > 0, "Invalid C parameter, it must be > 0.");
const long dims = samples[0][0].size();
trainer = structural_sequence_segmentation_trainer<T>(T(dims, params.window_size));
......@@ -532,11 +535,252 @@ segmenter_type train_sparse (
// ----------------------------------------------------------------------------------------
// Holds the three scores produced by the sequence segmenter testing and
// cross validation routines in this file: precision, recall, and F1.
struct segmenter_test
{
    // Zero every score so that a default-constructed object (which is what the
    // python binding creates for segmenter_test()) never exposes indeterminate
    // values when it is printed, pickled, or read.
    segmenter_test() : precision(0), recall(0), f1(0) {}

    double precision;
    double recall;
    double f1;
};
// Writes a segmenter_test to out one field at a time.
//
// The field order (precision, recall, f1) is the on-stream layout and must
// stay in sync with the matching deserialize() overload.
void serialize(const segmenter_test& item, std::ostream& out)
{
    serialize(item.precision, out);   // 1st: precision
    serialize(item.recall,    out);   // 2nd: recall
    serialize(item.f1,        out);   // 3rd: f1
}
// Reads a segmenter_test back from in, filling item field by field.
//
// Fields are consumed in the order precision, recall, f1, which is the
// order the matching serialize() overload writes them.
void deserialize(segmenter_test& item, std::istream& in)
{
    deserialize(item.precision, in);   // 1st: precision
    deserialize(item.recall,    in);   // 2nd: recall
    deserialize(item.f1,        in);   // 3rd: f1
}
// Formats item as a single line of text of the form
//   "precision: P recall: R f1-score: F"
// using the default stream formatting for doubles.
std::string segmenter_test__str__(const segmenter_test& item)
{
    std::ostringstream text;
    text << "precision: " << item.precision;
    text << " recall: "   << item.recall;
    text << " f1-score: " << item.f1;
    return text.str();
}
// Returns the __str__ text of item wrapped as "< ... >".
std::string segmenter_test__repr__(const segmenter_test& item)
{
    std::string text = "< ";
    text += segmenter_test__str__(item);
    text += " >";
    return text;
}
// ----------------------------------------------------------------------------------------
// Dense-vector overload of the python test_sequence_segmenter() routine.
//
// Validates the inputs with is_sequence_segmentation_problem(), then forwards
// to test_sequence_segmenter() using whichever of segmenter.segmenter0 ..
// segmenter.segmenter7 is selected by segmenter.mode.  Any other mode value
// results in a dlib::error("Invalid mode").
//
// The three elements of the returned 1x3 matrix are copied, in order, into
// the precision, recall, and f1 fields of the result.
const segmenter_test test_sequence_segmenter1 (
    const segmenter_type& segmenter,
    const std::vector<std::vector<dense_vect> >& samples,
    const std::vector<ranges>& segments
)
{
    pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs");

    matrix<double,1,3> scores;
    switch (segmenter.mode)
    {
        case 0:
            scores = test_sequence_segmenter(segmenter.segmenter0, samples, segments);
            break;
        case 1:
            scores = test_sequence_segmenter(segmenter.segmenter1, samples, segments);
            break;
        case 2:
            scores = test_sequence_segmenter(segmenter.segmenter2, samples, segments);
            break;
        case 3:
            scores = test_sequence_segmenter(segmenter.segmenter3, samples, segments);
            break;
        case 4:
            scores = test_sequence_segmenter(segmenter.segmenter4, samples, segments);
            break;
        case 5:
            scores = test_sequence_segmenter(segmenter.segmenter5, samples, segments);
            break;
        case 6:
            scores = test_sequence_segmenter(segmenter.segmenter6, samples, segments);
            break;
        case 7:
            scores = test_sequence_segmenter(segmenter.segmenter7, samples, segments);
            break;
        default:
            throw dlib::error("Invalid mode");
    }

    segmenter_test result;
    result.precision = scores(0);
    result.recall    = scores(1);
    result.f1        = scores(2);
    return result;
}
// Sparse-vector overload of the python test_sequence_segmenter() routine.
//
// Validates the inputs with is_sequence_segmentation_problem(), then forwards
// to test_sequence_segmenter() using whichever of segmenter.segmenter8 ..
// segmenter.segmenter15 is selected by segmenter.mode.  Any other mode value
// results in a dlib::error("Invalid mode").
//
// The three elements of the returned 1x3 matrix are copied, in order, into
// the precision, recall, and f1 fields of the result.
const segmenter_test test_sequence_segmenter2 (
    const segmenter_type& segmenter,
    const std::vector<std::vector<sparse_vect> >& samples,
    const std::vector<ranges>& segments
)
{
    pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs");

    matrix<double,1,3> scores;
    switch (segmenter.mode)
    {
        case 8:
            scores = test_sequence_segmenter(segmenter.segmenter8, samples, segments);
            break;
        case 9:
            scores = test_sequence_segmenter(segmenter.segmenter9, samples, segments);
            break;
        case 10:
            scores = test_sequence_segmenter(segmenter.segmenter10, samples, segments);
            break;
        case 11:
            scores = test_sequence_segmenter(segmenter.segmenter11, samples, segments);
            break;
        case 12:
            scores = test_sequence_segmenter(segmenter.segmenter12, samples, segments);
            break;
        case 13:
            scores = test_sequence_segmenter(segmenter.segmenter13, samples, segments);
            break;
        case 14:
            scores = test_sequence_segmenter(segmenter.segmenter14, samples, segments);
            break;
        case 15:
            scores = test_sequence_segmenter(segmenter.segmenter15, samples, segments);
            break;
        default:
            throw dlib::error("Invalid mode");
    }

    segmenter_test result;
    result.precision = scores(0);
    result.recall    = scores(1);
    result.f1        = scores(2);
    return result;
}
// ----------------------------------------------------------------------------------------
// Dense-vector overload of the python cross_validate_sequence_segmenter()
// routine.
//
// Requires that (samples, segments) pass is_sequence_segmentation_problem()
// and that 1 < folds <= samples.size().
//
// The three boolean options in params are packed into a mode number in the
// range 0..7:
//     use_BIO_model            -> adds 4
//     use_high_order_features  -> adds 2
//     allow_negative_weights   -> adds 1
// Mode N selects segmenter_type::feN as the template argument of the trainer,
// which is then set up by configure_trainer() and handed to
// cross_validate_sequence_segmenter().
//
// The three elements of the returned 1x3 matrix are copied, in order, into
// the precision, recall, and f1 fields of the result.
const segmenter_test cross_validate_sequence_segmenter1 (
    const std::vector<std::vector<dense_vect> >& samples,
    const std::vector<ranges>& segments,
    long folds,
    segmenter_params params
)
{
    pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs");
    pyassert(1 < folds && folds <= static_cast<long>(samples.size()), "folds argument is outside the valid range.");

    // Each option contributes one bit of the mode number.
    const int bio_bit        = params.use_BIO_model           ? 4 : 0;
    const int high_order_bit = params.use_high_order_features ? 2 : 0;
    const int negative_bit   = params.allow_negative_weights  ? 1 : 0;
    const int mode = bio_bit + high_order_bit + negative_bit;

    matrix<double,1,3> scores;
    switch (mode)
    {
        case 0:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe0> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 1:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe1> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 2:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe2> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 3:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe3> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 4:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe4> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 5:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe5> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 6:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe6> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 7:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe7> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        default:
            throw dlib::error("Invalid mode");
    }

    segmenter_test result;
    result.precision = scores(0);
    result.recall    = scores(1);
    result.f1        = scores(2);
    return result;
}
// Sparse-vector overload of the python cross_validate_sequence_segmenter()
// routine.
//
// Requires that (samples, segments) pass is_sequence_segmentation_problem()
// and that 1 < folds <= samples.size().
//
// The three boolean options in params are packed into a number in the range
// 0..7:
//     use_BIO_model            -> adds 4
//     use_high_order_features  -> adds 2
//     allow_negative_weights   -> adds 1
// and 8 is then added, giving a mode in the range 8..15.  Mode N selects
// segmenter_type::feN as the template argument of the trainer, which is then
// set up by configure_trainer() and handed to
// cross_validate_sequence_segmenter().
//
// The three elements of the returned 1x3 matrix are copied, in order, into
// the precision, recall, and f1 fields of the result.
const segmenter_test cross_validate_sequence_segmenter2 (
    const std::vector<std::vector<sparse_vect> >& samples,
    const std::vector<ranges>& segments,
    long folds,
    segmenter_params params
)
{
    pyassert(is_sequence_segmentation_problem(samples, segments), "Invalid inputs");
    pyassert(1 < folds && folds <= static_cast<long>(samples.size()), "folds argument is outside the valid range.");

    // Each option contributes one bit; the sparse modes start at 8.
    const int bio_bit        = params.use_BIO_model           ? 4 : 0;
    const int high_order_bit = params.use_high_order_features ? 2 : 0;
    const int negative_bit   = params.allow_negative_weights  ? 1 : 0;
    const int mode = 8 + bio_bit + high_order_bit + negative_bit;

    matrix<double,1,3> scores;
    switch (mode)
    {
        case 8:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe8> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 9:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe9> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 10:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe10> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 11:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe11> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 12:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe12> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 13:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe13> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 14:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe14> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        case 15:
        {
            structural_sequence_segmentation_trainer<segmenter_type::fe15> trainer;
            configure_trainer(samples, trainer, params);
            scores = cross_validate_sequence_segmenter(trainer, samples, segments, folds);
            break;
        }
        default:
            throw dlib::error("Invalid mode");
    }

    segmenter_test result;
    result.precision = scores(0);
    result.recall    = scores(1);
    result.f1        = scores(2);
    return result;
}
// ----------------------------------------------------------------------------------------
void bind_sequence_segmenter()
{
class_<segmenter_params>("segmenter_params",
"This class is used to define all the optional parameters to the \n\
train_sequence_segmenter() routine. ")
train_sequence_segmenter() and cross_validate_sequence_segmenter() routines. ")
.def_readwrite("use_BIO_model", &segmenter_params::use_BIO_model)
.def_readwrite("use_high_order_features", &segmenter_params::use_high_order_features)
.def_readwrite("allow_negative_weights", &segmenter_params::allow_negative_weights)
......@@ -545,6 +789,7 @@ train_sequence_segmenter() routine. ")
.def_readwrite("epsilon", &segmenter_params::epsilon)
.def_readwrite("max_cache_size", &segmenter_params::max_cache_size)
.def_readwrite("C", &segmenter_params::C, "SVM C parameter")
.def_readwrite("be_verbose", &segmenter_params::be_verbose)
.def("__repr__",&segmenter_params__repr__)
.def("__str__",&segmenter_params__str__)
.def_pickle(serialize_pickle<segmenter_params>());
......@@ -555,9 +800,26 @@ train_sequence_segmenter() routine. ")
.def_readonly("weights", &segmenter_type::get_weights)
.def_pickle(serialize_pickle<segmenter_type>());
class_<segmenter_test> ("segmenter_test")
.def_readwrite("precision", &segmenter_test::precision)
.def_readwrite("recall", &segmenter_test::recall)
.def_readwrite("f1", &segmenter_test::f1)
.def("__repr__",&segmenter_test__repr__)
.def("__str__",&segmenter_test__str__)
.def_pickle(serialize_pickle<segmenter_test>());
using boost::python::arg;
def("train_sequence_segmenter", train_dense, (arg("samples"), arg("segments"), arg("params")=segmenter_params()));
def("train_sequence_segmenter", train_sparse, (arg("samples"), arg("segments"), arg("params")=segmenter_params()));
def("test_sequence_segmenter", test_sequence_segmenter1);
def("test_sequence_segmenter", test_sequence_segmenter2);
def("cross_validate_sequence_segmenter", cross_validate_sequence_segmenter1,
(arg("samples"), arg("segments"), arg("folds"), arg("params")=segmenter_params()));
def("cross_validate_sequence_segmenter", cross_validate_sequence_segmenter2,
(arg("samples"), arg("segments"), arg("folds"), arg("params")=segmenter_params()));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment