Unverified commit 50b78da5 authored by Adrià Arrufat, committed by GitHub

Fix Barlow Twins loss gradient (#2518)

* Fix Barlow Twins loss gradient

* Update reference test accuracy after fix

* Round the empirical cross-correlation matrix

Just a tiny modification that allows the values to actually reach 255 (perfect white).
parent 39852f09
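
For context (standard from the Barlow Twins paper, Zbontar et al. 2021, not restated in the commit): the loss over the empirical cross-correlation matrix C is

    \mathcal{L}_{BT} = \sum_i \left(1 - C_{ii}\right)^2 + \lambda \sum_i \sum_{j \neq i} C_{ij}^2

so each off-diagonal entry contributes \partial \mathcal{L}_{BT} / \partial C_{ij} = 2\lambda\, C_{ij}. The old code scaled the off-diagonal gradient by \lambda instead of 2\lambda; the first hunk below corrects that.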
@@ -4066,8 +4066,8 @@ namespace dlib
     resizable_tensor off_mask(ones_matrix<float>(sample_size, sample_size) - identity_matrix<float>(sample_size));
     resizable_tensor off_diag(sample_size, sample_size);
     tt::multiply(false, off_diag, eccm, off_mask);
-    tt::gemm(1, grad_input_a, lambda, zb_norm, false, off_diag, false);
-    tt::gemm(1, grad_input_b, lambda, za_norm, false, off_diag, false);
+    tt::gemm(1, grad_input_a, 2 * lambda, zb_norm, false, off_diag, false);
+    tt::gemm(1, grad_input_b, 2 * lambda, za_norm, false, off_diag, false);
     // Compute the batch norm gradients, g and b grads are not used
     resizable_tensor g_grad, b_grad;
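
To see why the factor of 2 is right, here is a minimal standalone check (plain C++, not dlib code; the matrix entries and lambda value are made up for illustration) comparing the analytic off-diagonal gradient 2 * lambda * C_ij against a central finite difference:

#include <cmath>
#include <cstdio>

int main()
{
    const double lambda = 0.005; // illustrative off-diagonal weight
    const int n = 3;
    double C[3][3] = {{ 1.0, 0.2, -0.1},
                      { 0.3, 1.0,  0.4},
                      {-0.2, 0.1,  1.0}};

    // Off-diagonal penalty: lambda * sum_{i != j} C_ij^2
    const auto loss = [&]() {
        double s = 0;
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                if (i != j)
                    s += C[i][j] * C[i][j];
        return lambda * s;
    };

    const double eps = 1e-6;
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
        {
            if (i == j)
                continue;
            const double analytic = 2 * lambda * C[i][j]; // the fixed gradient
            const double c0 = C[i][j];
            C[i][j] = c0 + eps; const double lp = loss();
            C[i][j] = c0 - eps; const double lm = loss();
            C[i][j] = c0;
            const double numeric = (lp - lm) / (2 * eps);
            std::printf("dL/dC[%d][%d]: analytic=%.8f numeric=%.8f\n", i, j, analytic, numeric);
        }
}

With the old lambda scaling, the analytic values would come out exactly half the numeric ones.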
@@ -277,7 +277,7 @@ try
             // visualize it.
             tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
             eccm /= batch_size;
-            win.set_image(abs(mat(eccm)) * 255);
+            win.set_image(round(abs(mat(eccm)) * 255));
             win.set_title("Barlow Twins step#: " + to_string(trainer.get_train_one_step_calls()));
         }
     }
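
This second hunk only changes the visualization: abs(mat(eccm)) lies in [0, 1], and truncating the scaled float to an 8-bit pixel means a correlation just below 1 never displays as pure white. A tiny illustration (the 0.999 value is made up):

#include <cmath>
#include <cstdio>

int main()
{
    const double c = 0.999;                  // a near-perfect correlation late in training
    const double scaled = std::abs(c) * 255; // 254.745
    std::printf("truncated: %d\n", static_cast<int>(scaled));             // 254
    std::printf("rounded:   %d\n", static_cast<int>(std::round(scaled))); // 255
}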
@@ -304,12 +304,14 @@ try
     auto cross_validation_score = [&](const double c)
     {
         svm_multiclass_linear_trainer<linear_kernel<matrix<float, 0, 1>>, unsigned long> trainer;
-        trainer.set_num_threads(std::thread::hardware_concurrency());
         trainer.set_c(c);
+        trainer.set_epsilon(0.01);
+        trainer.set_max_iterations(100);
+        trainer.set_num_threads(std::thread::hardware_concurrency());
         cout << "C: " << c << endl;
         const auto cm = cross_validate_multiclass_trainer(trainer, features, training_labels, 3);
         const double accuracy = sum(diag(cm)) / sum(cm);
-        cout << "cross validation accuracy: " << accuracy << endl;;
+        cout << "cross validation accuracy: " << accuracy << endl;
         cout << "confusion matrix:\n " << cm << endl;
         return accuracy;
     };
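
Since cross_validation_score maps a single C value to an accuracy, it is a natural objective for dlib's global optimizer. A hedged sketch of how it could be driven with find_max_global (the bounds and call budget are illustrative assumptions, not taken from this commit):

#include <dlib/global_optimization.h>

// ... after defining cross_validation_score as in the hunk above ...
const auto result = dlib::find_max_global(
    cross_validation_score,
    1e-4,                           // assumed lower bound on C
    1e4,                            // assumed upper bound on C
    dlib::max_function_calls(15));  // assumed evaluation budget
cout << "best C: " << result.x << endl; // result.x holds the best C found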
@@ -345,7 +347,7 @@ try
         cout << " error rate: " << num_wrong / static_cast<double>(num_right + num_wrong) << endl;
     };

-    // We should get a training accuracy of around 93% and a testing accuracy of around 88%.
+    // We should get a training accuracy of around 93% and a testing accuracy of around 89%.
    cout << "\ntraining accuracy" << endl;
    compute_accuracy(features, training_labels);
    cout << "\ntesting accuracy" << endl;