Unverified Commit 5091e9c8 authored by Adrià Arrufat, committed by GitHub

Replace sgd-based fc classifier with svm_multiclass_linear_trainer (#2452)



* Replace fc classifier with svm_multiclass_linear_trainer

* Mention find_max_global()
Co-authored-by: Davis E. King <davis@dlib.net>

* Use double instead of float for extracted features
Co-authored-by: Davis E. King <davis@dlib.net>

* fix compilation with double features

* Revert "fix compilation with double features"

This reverts commit 76ebab4b91ed31d2332206fe8de092043c0f687f.

* Revert "Use double instead of float for extracted features"

This reverts commit 9a50809ebf0f420e72a3c2b4b856dc1a71b9c6b3.

* Find best C using global optimization
Co-authored-by: Davis E. King <davis@dlib.net>
parent f77189db
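
The pattern this commit introduces is to extract features once with the frozen backbone and then let find_max_global() choose the SVM's C parameter by maximizing cross-validation accuracy. Below is a minimal, self-contained sketch of that pattern, not the example code itself (which appears in the diff that follows): the toy 2-D Gaussian features, the C search range, and the call budget are illustrative stand-ins for the CIFAR-10 features and settings used in the actual example.

// A minimal sketch (not from the commit): choose C for a multiclass linear SVM
// with find_max_global() by maximizing 3-fold cross-validation accuracy.
#include <dlib/global_optimization.h>
#include <dlib/rand.h>
#include <dlib/svm_threaded.h>
#include <iostream>
#include <thread>
#include <vector>

using namespace dlib;

int main()
{
    // Toy stand-in for the backbone features: two Gaussian blobs in 2-D.
    std::vector<matrix<float, 0, 1>> samples;
    std::vector<unsigned long> labels;
    dlib::rand rnd;
    for (int i = 0; i < 200; ++i)
    {
        const unsigned long label = i % 2;
        matrix<float, 0, 1> x(2);
        x(0) = static_cast<float>(rnd.get_random_gaussian() + 3.0 * label);
        x(1) = static_cast<float>(rnd.get_random_gaussian() - 3.0 * label);
        samples.push_back(x);
        labels.push_back(label);
    }

    // Cross-validation accuracy as a function of the SVM's C parameter.
    auto cv_score = [&](const double c)
    {
        svm_multiclass_linear_trainer<linear_kernel<matrix<float, 0, 1>>, unsigned long> trainer;
        trainer.set_num_threads(std::thread::hardware_concurrency());
        trainer.set_c(c);
        const matrix<double> cm = cross_validate_multiclass_trainer(trainer, samples, labels, 3);
        return sum(diag(cm)) / sum(cm);
    };

    // Let the global optimizer search C in [1e-4, 1e4] (call budget is illustrative).
    const auto best = find_max_global(cv_score, 1e-4, 1e4, max_function_calls(30));
    std::cout << "best C: " << best.x(0) << ", cv accuracy: " << best.y << std::endl;

    // Train the final classifier with the chosen C and use it on one sample.
    svm_multiclass_linear_trainer<linear_kernel<matrix<float, 0, 1>>, unsigned long> trainer;
    trainer.set_num_threads(std::thread::hardware_concurrency());
    trainer.set_c(best.x(0));
    const auto df = trainer.train(samples, labels);
    std::cout << "predicted label of first sample: " << df(samples[0]) << std::endl;
}

The diff below applies the same calls (svm_multiclass_linear_trainer, cross_validate_multiclass_trainer, find_max_global) to the features produced by model::feats.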
@@ -36,10 +36,12 @@
     a max pooling layer afterwards, like the paper does.
 */
 
-#include <dlib/dnn.h>
-#include <dlib/data_io.h>
 #include <dlib/cmd_line_parser.h>
+#include <dlib/data_io.h>
+#include <dlib/dnn.h>
+#include <dlib/global_optimization.h>
 #include <dlib/gui_widgets.h>
+#include <dlib/svm_threaded.h>
 
 using namespace std;
 using namespace dlib;
@@ -82,14 +84,12 @@ namespace resnet50
 // This model namespace contains the definitions for:
 // - SSL model using the Barlow Twins loss, a projector head and an input_rgb_image_pair.
-// - Classifier model using the loss_multiclass_log, a fc layer and an input_rgb_image.
+// - A feature extractor model using the loss_metric (to get the outputs) and an input_rgb_image.
 namespace model
 {
     template <typename SUBNET> using projector = fc<128, relu<bn_fc<fc<512, SUBNET>>>>;
-    template <typename SUBNET> using classifier = fc<10, SUBNET>;
 
     using train = loss_barlow_twins<projector<resnet50::def<bn_con>::backbone<input_rgb_image_pair>>>;
-    using infer = loss_multiclass_log<classifier<resnet50::def<affine>::backbone<input_rgb_image>>>;
+    using feats = loss_metric<resnet50::def<affine>::backbone<input_rgb_image>>;
 }
 
 rectangle make_random_cropping_rect(
@@ -288,73 +288,65 @@ try
         serialize("resnet50_self_supervised_cifar_10.net") << layer<5>(net);
     }
 
-    // To check the quality of the learned feature representations, we will train a linear
-    // classififer on top of the frozen backbone.
-    model::infer inet;
-    // Assign the network, without the projector, which is only used for the self-supervised
-    // training.
-    layer<2>(inet) = layer<5>(net);
-    // Freeze the backbone
-    set_all_learning_rate_multipliers(layer<2>(inet), 0);
-    // Train the network
-    {
-        dnn_trainer<model::infer, adam> trainer(inet, adam(1e-6, 0.9, 0.999), gpus);
-        // Since this model doesn't train with pairs, just single images, we can increase
-        // the batch size.
-        trainer.set_mini_batch_size(2 * batch_size);
-        trainer.set_learning_rate(learning_rate);
-        trainer.set_min_learning_rate(min_learning_rate);
-        trainer.set_iterations_without_progress_threshold(5000);
-        trainer.set_synchronization_file("cifar_10_sync");
-        trainer.be_verbose();
-        cout << trainer << endl;
-
-        std::vector<matrix<rgb_pixel>> images;
-        std::vector<unsigned long> labels;
-        while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
-        {
-            images.clear();
-            labels.clear();
-            while (images.size() < trainer.get_mini_batch_size())
-            {
-                const auto idx = rnd.get_random_32bit_number() % training_images.size();
-                images.push_back(augment(training_images[idx], false, rnd));
-                labels.push_back(training_labels[idx]);
-            }
-            trainer.train_one_step(images, labels);
-        }
-        trainer.get_net();
-        inet.clean();
-        serialize("resnet50_cifar_10.dnn") << inet;
-    }
+    // Now, we initialize the feature extractor model with the backbone we have just learned.
+    model::feats fnet(layer<5>(net));
+    // And we will generate all the features for the training set to train a multiclass SVM
+    // classifier.
+    std::vector<matrix<float, 0, 1>> features;
+    cout << "Extracting features for linear classifier..." << endl;
+    features = fnet(training_images, 4 * batch_size);
+
+    // Find the most appropriate C setting using find_max_global.
+    auto cross_validation_score = [&](const double c)
+    {
+        svm_multiclass_linear_trainer<linear_kernel<matrix<float, 0, 1>>, unsigned long> trainer;
+        trainer.set_num_threads(std::thread::hardware_concurrency());
+        trainer.set_c(c);
+        cout << "C: " << c << endl;
+        const auto cm = cross_validate_multiclass_trainer(trainer, features, training_labels, 3);
+        const double accuracy = sum(diag(cm)) / sum(cm);
+        cout << "cross validation accuracy: " << accuracy << endl;
+        cout << "confusion matrix:\n " << cm << endl;
+        return accuracy;
+    };
+    const auto result = find_max_global(cross_validation_score, 1e-4, 10000, max_function_calls(50));
+    cout << "Best C: " << result.x(0) << endl;
+
+    // Proceed to train the SVM classifier with the best C.
+    svm_multiclass_linear_trainer<linear_kernel<matrix<float, 0, 1>>, unsigned long> trainer;
+    trainer.set_num_threads(std::thread::hardware_concurrency());
+    trainer.set_c(result.x(0));
+    cout << "Training Multiclass SVM..." << endl;
+    const auto df = trainer.train(features, training_labels);
+    serialize("multiclass_svm_cifar_10.dat") << df;
 
     // Finally, we can compute the accuracy of the model on the CIFAR-10 train and test images.
-    auto compute_accuracy = [&inet, batch_size](
-        const std::vector<matrix<rgb_pixel>>& images,
+    auto compute_accuracy = [&fnet, &df, batch_size](
+        const std::vector<matrix<float, 0, 1>>& samples,
         const std::vector<unsigned long>& labels
     )
     {
         size_t num_right = 0;
         size_t num_wrong = 0;
-        const auto preds = inet(images, batch_size * 2);
        for (size_t i = 0; i < labels.size(); ++i)
         {
-            if (labels[i] == preds[i])
+            if (labels[i] == df(samples[i]))
                 ++num_right;
             else
                 ++num_wrong;
         }
-        cout << "num right: " << num_right << endl;
-        cout << "num wrong: " << num_wrong << endl;
-        cout << "accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << endl;
-        cout << "error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
+        cout << " num right: " << num_right << endl;
+        cout << " num wrong: " << num_wrong << endl;
+        cout << " accuracy: " << num_right / static_cast<double>(num_right + num_wrong) << endl;
+        cout << " error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
     };
-    // If everything works as expected, we should get accuracies that are between 87% and 90%.
-    cout << "training accuracy" << endl;
-    compute_accuracy(training_images, training_labels);
+    // We should get a training accuracy of around 93% and a testing accuracy of around 88%.
+    cout << "\ntraining accuracy" << endl;
+    compute_accuracy(features, training_labels);
     cout << "\ntesting accuracy" << endl;
-    compute_accuracy(testing_images, testing_labels);
+    features = fnet(testing_images, 4 * batch_size);
+    compute_accuracy(features, testing_labels);
 
     return EXIT_SUCCESS;
 }
 catch (const exception& e)
...
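
The example above serializes the trained classifier to multiclass_svm_cifar_10.dat. A sketch of how that file could be reloaded and applied later, assuming the stored decision function matches the trainer type used above; the zero-filled feature vector is a dummy stand-in for a real feature produced by the self-supervised backbone.

// Illustrative sketch only: reload the serialized classifier and classify one
// precomputed feature vector.
#include <dlib/svm_threaded.h>
#include <iostream>

using namespace dlib;

int main()
{
    using sample_type = matrix<float, 0, 1>;
    multiclass_linear_decision_function<linear_kernel<sample_type>, unsigned long> df;
    deserialize("multiclass_svm_cifar_10.dat") >> df;

    // The decision function stores one weight row per class, so its column
    // count tells us the expected feature dimensionality.
    sample_type feature = zeros_matrix<float>(df.weights.nc(), 1);
    std::cout << "predicted CIFAR-10 label: " << df(feature) << std::endl;
}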