"examples/git@developer.sourcefind.cn:modelzoo/bert4torch.git" did not exist on "58935387458ba1dee228fbd261bf9878a5d7ae32"
Commit 7745c5f9 authored by Davis King's avatar Davis King
Browse files

Modified the probabilistic() trainer adapter (and the...

Modified the probabilistic() trainer adapter (and the train_probabilistic_decision_function() routine)
so that they work with objects which have an interface compatible with std::vector rather than strictly
just std::vector objects.  For example, random_subset_selector objects can be used (as was the case in
some previous versions of dlib).

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404211
parent eba0c564
...@@ -470,19 +470,19 @@ namespace dlib ...@@ -470,19 +470,19 @@ namespace dlib
template < template <
typename trainer_type, typename trainer_type,
typename sample_type, typename sample_vector_type,
typename scalar_type, typename label_vector_type
typename alloc_type1,
typename alloc_type2
> >
const probabilistic_function<typename trainer_type::trained_function_type> const probabilistic_function<typename trainer_type::trained_function_type>
train_probabilistic_decision_function ( train_probabilistic_decision_function (
const trainer_type& trainer, const trainer_type& trainer,
const std::vector<sample_type,alloc_type1>& x, const sample_vector_type& x,
const std::vector<scalar_type,alloc_type2>& y, const label_vector_type& y,
const long folds const long folds
) )
{ {
typedef typename sample_vector_type::value_type sample_type;
typedef typename label_vector_type::value_type scalar_type;
/* /*
This function fits a sigmoid function to the output of the This function fits a sigmoid function to the output of the
...@@ -519,14 +519,14 @@ namespace dlib ...@@ -519,14 +519,14 @@ namespace dlib
const long num_neg_train_samples = num_neg - num_neg_test_samples; const long num_neg_train_samples = num_neg - num_neg_test_samples;
typename trainer_type::trained_function_type d; typename trainer_type::trained_function_type d;
std::vector<sample_type,alloc_type1> x_test, x_train; std::vector<sample_type> x_test, x_train;
std::vector<scalar_type,alloc_type2> y_test, y_train; std::vector<scalar_type> y_test, y_train;
x_test.resize (num_pos_test_samples + num_neg_test_samples); x_test.resize (num_pos_test_samples + num_neg_test_samples);
y_test.resize (num_pos_test_samples + num_neg_test_samples); y_test.resize (num_pos_test_samples + num_neg_test_samples);
x_train.resize(num_pos_train_samples + num_neg_train_samples); x_train.resize(num_pos_train_samples + num_neg_train_samples);
y_train.resize(num_pos_train_samples + num_neg_train_samples); y_train.resize(num_pos_train_samples + num_neg_train_samples);
typedef std::vector<scalar_type, alloc_type2 > dvector; typedef std::vector<scalar_type > dvector;
dvector out; dvector out;
dvector target; dvector target;
...@@ -658,14 +658,12 @@ namespace dlib ...@@ -658,14 +658,12 @@ namespace dlib
) : trainer(trainer_),folds(folds_) {} ) : trainer(trainer_),folds(folds_) {}
template < template <
typename sample_type, typename T,
typename scalar_type, typename U
typename alloc_type1,
typename alloc_type2
> >
const trained_function_type train ( const trained_function_type train (
const std::vector<sample_type,alloc_type1>& samples, const T& samples,
const std::vector<scalar_type,alloc_type2>& labels const U& labels
) const ) const
{ {
return train_probabilistic_decision_function(trainer, samples, labels, folds); return train_probabilistic_decision_function(trainer, samples, labels, folds);
......
...@@ -68,22 +68,21 @@ namespace dlib ...@@ -68,22 +68,21 @@ namespace dlib
template < template <
typename trainer_type, typename trainer_type,
typename sample_type, typename sample_vector_type,
typename scalar_type, typename label_vector_type
typename alloc_type1,
typename alloc_type2
> >
const probabilistic_function<typename trainer_type::trained_function_type> const probabilistic_function<typename trainer_type::trained_function_type>
train_probabilistic_decision_function ( train_probabilistic_decision_function (
const trainer_type& trainer, const trainer_type& trainer,
const std::vector<sample_type,alloc_type1>& x, const sample_vector_type& x,
const std::vector<scalar_type,alloc_type2>& y, const label_vector_type& y,
const long folds const long folds
); );
/*! /*!
requires requires
- 1 < folds <= x.size() - 1 < folds <= x.size()
- is_binary_classification_problem(x,y) == true - is_binary_classification_problem(x,y) == true
- x and y must be std::vector objects or types with a compatible interface.
- trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer) - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer)
ensures ensures
- trains a classifier given the training samples in x and labels in y. - trains a classifier given the training samples in x and labels in y.
......
...@@ -73,6 +73,7 @@ set (tests ...@@ -73,6 +73,7 @@ set (tests
opt_qp_solver.cpp opt_qp_solver.cpp
pipe.cpp pipe.cpp
pixel.cpp pixel.cpp
probabilistic.cpp
queue.cpp queue.cpp
rand.cpp rand.cpp
read_write_mutex.cpp read_write_mutex.cpp
......
...@@ -83,6 +83,7 @@ SRC += optimization_test_functions.cpp ...@@ -83,6 +83,7 @@ SRC += optimization_test_functions.cpp
SRC += opt_qp_solver.cpp SRC += opt_qp_solver.cpp
SRC += pipe.cpp SRC += pipe.cpp
SRC += pixel.cpp SRC += pixel.cpp
SRC += probabilistic.cpp
SRC += queue.cpp SRC += queue.cpp
SRC += rand.cpp SRC += rand.cpp
SRC += read_write_mutex.cpp SRC += read_write_mutex.cpp
......
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <dlib/matrix.h>
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <vector>
#include "../stl_checked.h"
#include "../array.h"
#include "../rand.h"
#include "checkerboard.h"
#include <dlib/statistics.h>
#include "tester.h"
#include <dlib/svm_threaded.h>
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
logger dlog("test.probabilistic");
// ----------------------------------------------------------------------------------------
class test_probabilistic : public tester
{
public:
test_probabilistic (
) :
tester ("test_probabilistic",
"Runs tests on the probabilistic trainer adapter.")
{}
void perform_test (
)
{
print_spinner();
typedef double scalar_type;
typedef matrix<scalar_type,2,1> sample_type;
std::vector<sample_type> x;
std::vector<matrix<double,0,1> > x_linearized;
std::vector<scalar_type> y;
get_checkerboard_problem(x,y, 1000, 2);
random_subset_selector<sample_type> rx;
random_subset_selector<scalar_type> ry;
rx.set_max_size(x.size());
ry.set_max_size(x.size());
dlog << LINFO << "pos labels: "<< sum(vector_to_matrix(y) == +1);
dlog << LINFO << "neg labels: "<< sum(vector_to_matrix(y) == -1);
for (unsigned long i = 0; i < x.size(); ++i)
{
rx.add(x[i]);
ry.add(y[i]);
}
const scalar_type gamma = 2.0;
typedef radial_basis_kernel<sample_type> kernel_type;
krr_trainer<kernel_type> krr_trainer;
krr_trainer.use_classification_loss_for_loo_cv();
krr_trainer.set_kernel(kernel_type(gamma));
krr_trainer.set_basis(randomly_subsample(x, 100));
probabilistic_decision_function<kernel_type> df;
dlog << LINFO << "cross validation: " << cross_validate_trainer(krr_trainer, rx,ry, 4);
print_spinner();
running_stats<scalar_type> rs_pos, rs_neg;
print_spinner();
df = probabilistic(krr_trainer,3).train(x, y);
for (unsigned long i = 0; i < x.size(); ++i)
{
if (y[i] > 0)
rs_pos.add(df(x[i]));
else
rs_neg.add(df(x[i]));
}
dlog << LINFO << "rs_pos.mean(): "<< rs_pos.mean();
dlog << LINFO << "rs_neg.mean(): "<< rs_neg.mean();
DLIB_TEST_MSG(rs_pos.mean() > 0.95, rs_pos.mean());
DLIB_TEST_MSG(rs_neg.mean() < 0.05, rs_neg.mean());
rs_pos.clear();
rs_neg.clear();
print_spinner();
df = probabilistic(krr_trainer,3).train(rx, ry);
for (unsigned long i = 0; i < x.size(); ++i)
{
if (y[i] > 0)
rs_pos.add(df(x[i]));
else
rs_neg.add(df(x[i]));
}
dlog << LINFO << "rs_pos.mean(): "<< rs_pos.mean();
dlog << LINFO << "rs_neg.mean(): "<< rs_neg.mean();
DLIB_TEST_MSG(rs_pos.mean() > 0.95, rs_pos.mean());
DLIB_TEST_MSG(rs_neg.mean() < 0.05, rs_neg.mean());
rs_pos.clear();
rs_neg.clear();
}
} a;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment