"...text-generation-inference.git" did not exist on "0759ec495e15a865d2a59befc2b796b5acc09b50"
Commit d175c350 authored by Juha Reunanen's avatar Juha Reunanen Committed by Davis E. King
Browse files

Instance segmentation (#1918)

* Add instance segmentation example - first version of training code

* Add MMOD options; get rid of the cache approach, and instead load all MMOD rects upfront

* Improve console output

* Set filter count

* Minor tweaking

* Inference - first version, at least compiles!

* Ignore overlapped boxes

* Ignore even small instances

* Set overlaps_ignore

* Add TODO remarks

* Revert "Set overlaps_ignore"

This reverts commit 65adeff1f89af62b10c691e7aa86c04fc358d03e.

* Set result size

* Set label image size

* Take ignore-color into account

* Fix the cropping rect's aspect ratio; also slightly expand the rect

* Draw the largest findings last

* Improve masking of the current instance

* Add some perturbation to the inputs

* Simplify ground-truth reading; fix random cropping

* Read even class labels

* Tweak default minibatch size

* Learn only one class

* Really train only instances of the selected class

* Remove outdated TODO remark

* Automatically skip images with no detections

* Print to console what was found

* Fix class index problem

* Fix indentation

* Allow to choose multiple classes

* Draw rect in the color of the corresponding class

* Write detector window classes to ostream; also group detection windows by class (when ostreaming)

* Train a separate instance segmentation network for each classlabel

* Use separate synchronization file for each seg net of each class

* Allow more overlap

* Fix sorting criterion

* Fix interpolating the predicted mask

* Improve bilinear interpolation: if output type is an integer, round instead of truncating

* Add helpful comments

* Ignore large aspect ratios; refactor the code; tweak some network parameters

* Simplify the segmentation network structure; make the object detection network more complex in turn

* Problem: CUDA errors not reported properly to console
Solution: stop and join data loader threads even in case of exceptions

* Minor parameters tweaking

* Loss may have increased, even if prob_loss_increasing_thresh > prob_loss_increasing_thresh_max_value

* Add previous_loss_values_dump_amount to previous_loss_values.size() when deciding if loss has been increasing

* Improve behaviour when loss actually increased after disk sync

* Revert some of the earlier change

* Disregard dumped loss values only when deciding if learning rate should be shrunk, but *not* when deciding if loss has been going up since last disk sync

* Revert "Revert some of the earlier change"

This reverts commit 6c852124efe6473a5c962de0091709129d6fcde3.

* Keep enough previous loss values, until the disk sync

* Fix maintaining the dumped (now "effectively disregarded") loss values count

* Detect cats instead of aeroplanes

* Add helpful logging

* Clarify the intention and the code

* Review fixes

* Add operator== for the other pixel types as well; remove the inline

* If available, use constexpr if

* Revert "If available, use constexpr if"

This reverts commit 503d4dd3355ff8ad613116e3ffcc0fa664674f69.

* Simplify code as per review comments

* Keep estimating steps_without_progress, even if steps_since_last_learning_rate_shrink < iter_without_progress_thresh

* Clarify console output

* Revert "Keep estimating steps_without_progress, even if steps_since_last_learning_rate_shrink < iter_without_progress_thresh"

This reverts commit 9191ebc7762d17d81cdfc334a80ca9a667365740.

* To keep the changes to a bare minimum, revert the steps_since_last_learning_rate_shrink change after all (at least for now)

* Even empty out some of the previous test loss values

* Minor review fixes

* Can't use C++14 features here

* Do not use the struct name as a variable name
parent 7f2be82e
......@@ -993,6 +993,36 @@ namespace dlib
}
}
// Write the detector windows to the stream, grouped by label.
// Example output: aeroplane:74x30,131x30,70x45,54x70,198x30;bicycle:70x57,32x70,70x32,51x70,128x30,30x121;car:70x36,70x60,99x30,52x70,30x83,30x114,30x200
inline std::ostream& operator<<(std::ostream& out, const std::vector<mmod_options::detector_window_details>& detector_windows)
{
    // Bucket the windows by their class label. std::map keeps the labels sorted,
    // so the output order is deterministic.
    std::map<std::string, std::deque<mmod_options::detector_window_details>> detector_windows_by_label;
    for (const auto& detector_window : detector_windows)
        detector_windows_by_label[detector_window.label].push_back(detector_window);
    size_t label_count = 0;
    for (const auto& i : detector_windows_by_label)
    {
        const auto& label = i.first;
        // Renamed from `detector_windows`: the original name shadowed the function
        // parameter of the same name, which is confusing and error-prone.
        const auto& windows_for_label = i.second;
        if (label_count++ > 0)
            out << ";";
        out << label << ":";
        for (size_t j = 0; j < windows_for_label.size(); ++j)
        {
            out << windows_for_label[j].width << "x" << windows_for_label[j].height;
            if (j + 1 < windows_for_label.size())
                out << ",";
        }
    }
    return out;
}
// ----------------------------------------------------------------------------------------
class loss_mmod_
......@@ -1395,15 +1425,10 @@ namespace dlib
{
out << "loss_mmod\t (";
out << "detector_windows:(";
auto& opts = item.options;
for (size_t i = 0; i < opts.detector_windows.size(); ++i)
{
out << opts.detector_windows[i].width << "x" << opts.detector_windows[i].height;
if (i+1 < opts.detector_windows.size())
out << ",";
}
out << ")";
out << "detector_windows:(" << opts.detector_windows << ")";
out << ", loss per FA:" << opts.loss_per_false_alarm;
out << ", loss per miss:" << opts.loss_per_missed_target;
out << ", truth match IOU thresh:" << opts.truth_match_iou_threshold;
......
......@@ -460,6 +460,7 @@ namespace dlib
test_steps_without_progress = 0;
previous_loss_values.clear();
test_previous_loss_values.clear();
previous_loss_values_to_keep_until_disk_sync.clear();
}
learning_rate = lr;
lr_schedule.set_size(0);
......@@ -602,6 +603,11 @@ namespace dlib
// discard really old loss values.
while (previous_loss_values.size() > iter_without_progress_thresh)
previous_loss_values.pop_front();
// separately keep another loss history until disk sync
// (but only if disk sync is enabled)
if (!sync_filename.empty())
previous_loss_values_to_keep_until_disk_sync.push_back(loss);
}
template <typename T>
......@@ -700,10 +706,10 @@ namespace dlib
// optimization has flattened out, so drop the learning rate.
learning_rate = learning_rate_shrink*learning_rate;
test_steps_without_progress = 0;
// Empty out some of the previous loss values so that test_steps_without_progress
// will decrease below test_iter_without_progress_thresh.
for (unsigned long cnt = 0; cnt < test_previous_loss_values_dump_amount+test_iter_without_progress_thresh/10 && test_previous_loss_values.size() > 0; ++cnt)
test_previous_loss_values.pop_front();
drop_some_test_previous_loss_values();
}
}
}
......@@ -820,10 +826,10 @@ namespace dlib
// optimization has flattened out, so drop the learning rate.
learning_rate = learning_rate_shrink*learning_rate;
steps_without_progress = 0;
// Empty out some of the previous loss values so that steps_without_progress
// will decrease below iter_without_progress_thresh.
for (unsigned long cnt = 0; cnt < previous_loss_values_dump_amount+iter_without_progress_thresh/10 && previous_loss_values.size() > 0; ++cnt)
previous_loss_values.pop_front();
drop_some_previous_loss_values();
}
}
}
......@@ -895,7 +901,7 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out)
{
item.wait_for_thread_to_pause();
int version = 12;
int version = 13;
serialize(version, out);
size_t nl = dnn_trainer::num_layers;
......@@ -924,14 +930,14 @@ namespace dlib
serialize(item.test_previous_loss_values, out);
serialize(item.previous_loss_values_dump_amount, out);
serialize(item.test_previous_loss_values_dump_amount, out);
serialize(item.previous_loss_values_to_keep_until_disk_sync, out);
}
friend void deserialize(dnn_trainer& item, std::istream& in)
{
item.wait_for_thread_to_pause();
int version = 0;
deserialize(version, in);
if (version != 12)
if (version != 13)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
size_t num_layers = 0;
......@@ -970,6 +976,7 @@ namespace dlib
deserialize(item.test_previous_loss_values, in);
deserialize(item.previous_loss_values_dump_amount, in);
deserialize(item.test_previous_loss_values_dump_amount, in);
deserialize(item.previous_loss_values_to_keep_until_disk_sync, in);
if (item.devices.size() > 1)
{
......@@ -987,6 +994,20 @@ namespace dlib
}
}
// Discard a batch of the oldest recorded loss values so that the running
// steps_without_progress estimate can fall back below iter_without_progress_thresh.
void drop_some_previous_loss_values()
{
    auto drops_remaining = previous_loss_values_dump_amount + iter_without_progress_thresh / 10;
    while (drops_remaining > 0 && !previous_loss_values.empty())
    {
        previous_loss_values.pop_front();
        --drops_remaining;
    }
}
// Discard a batch of the oldest recorded test loss values so that the running
// test_steps_without_progress estimate can fall back below test_iter_without_progress_thresh.
void drop_some_test_previous_loss_values()
{
    auto drops_remaining = test_previous_loss_values_dump_amount + test_iter_without_progress_thresh / 10;
    while (drops_remaining > 0 && !test_previous_loss_values.empty())
    {
        test_previous_loss_values.pop_front();
        --drops_remaining;
    }
}
void sync_to_disk (
bool do_it_now = false
)
......@@ -1020,6 +1041,20 @@ namespace dlib
sync_file_reloaded = true;
if (verbose)
std::cout << "Loss has been increasing, reloading saved state from " << newest_syncfile() << std::endl;
// Are we repeatedly hitting our head against the wall? If so, then we
// might be better off giving up at this learning rate, and trying a
// lower one instead.
if (prob_loss_increasing_thresh >= prob_loss_increasing_thresh_max_value)
{
std::cout << "(and while at it, also shrinking the learning rate)" << std::endl;
learning_rate = learning_rate_shrink * learning_rate;
steps_without_progress = 0;
test_steps_without_progress = 0;
drop_some_previous_loss_values();
drop_some_test_previous_loss_values();
}
}
else
{
......@@ -1057,34 +1092,39 @@ namespace dlib
if (!std::ifstream(newest_syncfile(), std::ios::binary))
return false;
for (auto x : previous_loss_values)
// Now look at the data since a little before the last disk sync. We will
// check if the loss is getting better or worse.
while (previous_loss_values_to_keep_until_disk_sync.size() > 2 * gradient_updates_since_last_sync)
previous_loss_values_to_keep_until_disk_sync.pop_front();
running_gradient g;
for (auto x : previous_loss_values_to_keep_until_disk_sync)
{
// If we get a NaN value of loss assume things have gone horribly wrong and
// we should reload the state of the trainer.
if (std::isnan(x))
return true;
g.add(x);
}
// if we haven't seen much data yet then just say false. Or, alternatively, if
// it's been too long since the last sync then don't reload either.
if (gradient_updates_since_last_sync < 30 || previous_loss_values.size() < 2*gradient_updates_since_last_sync)
// if we haven't seen much data yet then just say false.
if (gradient_updates_since_last_sync < 30)
return false;
// Now look at the data since a little before the last disk sync. We will
// check if the loss is getting better or worse.
running_gradient g;
for (size_t i = previous_loss_values.size() - 2*gradient_updates_since_last_sync; i < previous_loss_values.size(); ++i)
g.add(previous_loss_values[i]);
// if the loss is very likely to be increasing then return true
const double prob = g.probability_gradient_greater_than(0);
if (prob > prob_loss_increasing_thresh && prob_loss_increasing_thresh <= prob_loss_increasing_thresh_max_value)
if (prob > prob_loss_increasing_thresh)
{
// Exponentially decay the threshold towards 1 so that if we keep finding
// the loss to be increasing over and over we will make the test
// progressively harder and harder until it fails, therefore ensuring we
// can't get stuck reloading from a previous state over and over.
prob_loss_increasing_thresh = 0.1*prob_loss_increasing_thresh + 0.9*1;
prob_loss_increasing_thresh = std::min(
0.1*prob_loss_increasing_thresh + 0.9*1,
prob_loss_increasing_thresh_max_value
);
return true;
}
else
......@@ -1247,6 +1287,8 @@ namespace dlib
std::atomic<unsigned long> test_steps_without_progress;
std::deque<double> test_previous_loss_values;
std::deque<double> previous_loss_values_to_keep_until_disk_sync;
std::atomic<double> learning_rate_shrink;
std::chrono::time_point<std::chrono::system_clock> last_sync_time;
std::string sync_filename;
......
......@@ -865,10 +865,18 @@ namespace dlib
float fout[4];
out.store(fout);
out_img[r][c] = static_cast<T>(fout[0]);
out_img[r][c+1] = static_cast<T>(fout[1]);
out_img[r][c+2] = static_cast<T>(fout[2]);
out_img[r][c+3] = static_cast<T>(fout[3]);
const auto convert_to_output_type = [](float value)
{
if (std::is_integral<T>::value)
return static_cast<T>(value + 0.5);
else
return static_cast<T>(value);
};
out_img[r][c] = convert_to_output_type(fout[0]);
out_img[r][c+1] = convert_to_output_type(fout[1]);
out_img[r][c+2] = convert_to_output_type(fout[2]);
out_img[r][c+3] = convert_to_output_type(fout[3]);
}
x = -x_scale + c*x_scale;
for (; c < out_img.nc(); ++c)
......
......@@ -125,6 +125,13 @@ namespace dlib
unsigned char red;
unsigned char green;
unsigned char blue;
// Two rgb_pixel values compare equal exactly when all three channels match.
bool operator == (const rgb_pixel& that) const
{
    return red == that.red && green == that.green && blue == that.blue;
}
};
// ----------------------------------------------------------------------------------------
......@@ -151,6 +158,13 @@ namespace dlib
unsigned char blue;
unsigned char green;
unsigned char red;
// Two bgr_pixel values compare equal exactly when all three channels match.
bool operator == (const bgr_pixel& that) const
{
    return blue == that.blue && green == that.green && red == that.red;
}
};
// ----------------------------------------------------------------------------------------
......@@ -177,6 +191,14 @@ namespace dlib
unsigned char green;
unsigned char blue;
unsigned char alpha;
// Two rgb_alpha_pixel values compare equal exactly when all four channels match.
bool operator == (const rgb_alpha_pixel& that) const
{
    return red == that.red && green == that.green
        && blue == that.blue && alpha == that.alpha;
}
};
// ----------------------------------------------------------------------------------------
......@@ -200,6 +222,13 @@ namespace dlib
unsigned char h;
unsigned char s;
unsigned char i;
// Two hsi_pixel values compare equal exactly when all three channels match.
bool operator == (const hsi_pixel& that) const
{
    return h == that.h && s == that.s && i == that.i;
}
};
// ----------------------------------------------------------------------------------------
......@@ -222,6 +251,13 @@ namespace dlib
unsigned char l;
unsigned char a;
unsigned char b;
// Two lab_pixel values compare equal exactly when all three channels match.
bool operator == (const lab_pixel& that) const
{
    return l == that.l && a == that.a && b == that.b;
}
};
// ----------------------------------------------------------------------------------------
......
......@@ -148,8 +148,10 @@ if (NOT USING_OLD_VISUAL_STUDIO_COMPILER)
add_gui_example(dnn_mmod_find_cars2_ex)
add_example(dnn_mmod_train_find_cars_ex)
add_gui_example(dnn_semantic_segmentation_ex)
add_gui_example(dnn_instance_segmentation_ex)
add_example(dnn_imagenet_train_ex)
add_example(dnn_semantic_segmentation_train_ex)
add_example(dnn_instance_segmentation_train_ex)
add_example(dnn_metric_learning_on_images_ex)
endif()
......
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
This example shows how to do instance segmentation on an image using net pretrained
on the PASCAL VOC2012 dataset. For an introduction to what instance segmentation is,
see the accompanying header file dnn_instance_segmentation_ex.h.
Instructions how to run the example:
1. Download the PASCAL VOC2012 data, and untar it somewhere.
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
2. Build the dnn_instance_segmentation_train_ex example program.
3. Run:
./dnn_instance_segmentation_train_ex /path/to/VOC2012
4. Wait while the network is being trained.
5. Build the dnn_instance_segmentation_ex example program.
6. Run:
./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images
An alternative to steps 2-4 above is to download a pre-trained network
from here: http://dlib.net/files/instance_segmentation_voc2012net.dnn
It would be a good idea to become familiar with dlib's DNN tooling before reading this
example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
before reading this example program.
*/
#include "dnn_instance_segmentation_ex.h"
#include "pascal_voc_2012.h"
#include <iostream>
#include <dlib/data_io.h>
#include <dlib/gui_widgets.h>
using namespace std;
using namespace dlib;
// ----------------------------------------------------------------------------------------
// Load the trained detection and per-class segmentation networks, run instance
// segmentation on every supported image found under the directory given as argv[1],
// and display each input image next to its predicted per-instance color mask.
// Returns 0 on success, 1 on bad usage or on an unhandled exception.
int main(int argc, char** argv) try
{
    if (argc != 2)
    {
        cout << "You call this program like this: " << endl;
        cout << "./dnn_instance_segmentation_train_ex /path/to/images" << endl;
        cout << endl;
        cout << "You will also need a trained '" << instance_segmentation_net_filename << "' file." << endl;
        cout << "You can either train it yourself (see example program" << endl;
        cout << "dnn_instance_segmentation_train_ex), or download a" << endl;
        cout << "copy from here: http://dlib.net/files/" << instance_segmentation_net_filename << endl;
        return 1;
    }
    // Read the file containing the trained networks from the working directory.
    det_anet_type det_net;
    std::map<std::string, seg_bnet_type> seg_nets_by_class;
    deserialize(instance_segmentation_net_filename) >> det_net >> seg_nets_by_class;
    // Show inference results in a window.
    image_window win;
    matrix<rgb_pixel> input_image;
    // Find supported image files.
    const std::vector<file> files = dlib::get_files_in_directory_tree(argv[1],
        dlib::match_endings(".jpeg .jpg .png"));
    dlib::rand rnd;
    cout << "Found " << files.size() << " images, processing..." << endl;
    // Named `image_file` (not `file`) so the loop variable does not shadow the dlib::file type.
    for (const file& image_file : files)
    {
        // Load the input image.
        load_image(input_image, image_file.full_name());
        // Draw largest objects last, so big instances cannot completely cover smaller ones.
        const auto sort_instances = [](const std::vector<mmod_rect>& input) {
            auto output = input;
            const auto compare_area = [](const mmod_rect& lhs, const mmod_rect& rhs) {
                return lhs.rect.area() < rhs.rect.area();
            };
            std::sort(output.begin(), output.end(), compare_area);
            return output;
        };
        // Find instances in the input image
        const auto instances = sort_instances(det_net(input_image));
        matrix<rgb_pixel> rgb_label_image;
        matrix<rgb_pixel> input_chip;
        rgb_label_image.set_size(input_image.nr(), input_image.nc());
        rgb_label_image = rgb_pixel(0, 0, 0);
        bool found_something = false;
        for (const auto& instance : instances)
        {
            if (!found_something)
            {
                cout << "Found ";
                found_something = true;
            }
            else
            {
                cout << ", ";
            }
            cout << instance.label;
            // Crop a square region around the detection, resized to the fixed input
            // size expected by the segmentation network.
            const auto cropping_rect = get_cropping_rect(instance.rect);
            // Named `details` (not `chip_details`) so the variable does not shadow the dlib::chip_details type.
            const chip_details details(cropping_rect, chip_dims(seg_dim, seg_dim));
            extract_image_chip(input_image, details, input_chip, interpolate_bilinear());
            const auto i = seg_nets_by_class.find(instance.label);
            if (i == seg_nets_by_class.end())
            {
                // per-class segmentation net not found, so we must be using the same net for all classes
                // (see bool separate_seg_net_for_each_class in dnn_instance_segmentation_train_ex.cpp)
                DLIB_CASSERT(seg_nets_by_class.size() == 1);
                DLIB_CASSERT(seg_nets_by_class.begin()->first == "");
            }
            auto& seg_net = i != seg_nets_by_class.end()
                ? i->second // use the segmentation net trained for this class
                : seg_nets_by_class.begin()->second; // use the same segmentation net for all classes
            // Predict the per-pixel mask for this instance.
            const auto mask = seg_net(input_chip);
            const rgb_pixel random_color(
                rnd.get_random_8bit_number(),
                rnd.get_random_8bit_number(),
                rnd.get_random_8bit_number()
            );
            // Scale the predicted mask back to the size of the crop in the original image.
            dlib::matrix<uint16_t> resized_mask(
                static_cast<int>(details.rect.height()),
                static_cast<int>(details.rect.width())
            );
            dlib::resize_image(mask, resized_mask);
            // Paint every mask pixel that lands inside the label image with this instance's color.
            for (int r = 0; r < resized_mask.nr(); ++r)
            {
                for (int c = 0; c < resized_mask.nc(); ++c)
                {
                    if (resized_mask(r, c))
                    {
                        const auto y = details.rect.top() + r;
                        const auto x = details.rect.left() + c;
                        if (y >= 0 && y < rgb_label_image.nr() && x >= 0 && x < rgb_label_image.nc())
                            rgb_label_image(y, x) = random_color;
                    }
                }
            }
            // Outline the detection in the color associated with its class.
            const Voc2012class& voc2012_class = find_voc2012_class(
                [&instance](const Voc2012class& candidate) {
                    return candidate.classlabel == instance.label;
                }
            );
            dlib::draw_rectangle(rgb_label_image, instance.rect, voc2012_class.rgb_label, 1);
        }
        // Show the input image on the left, and the predicted RGB labels on the right.
        win.set_image(join_rows(input_image, rgb_label_image));
        if (!instances.empty())
        {
            cout << " in " << image_file.name() << " - hit enter to process the next image";
            cin.get();
        }
    }
}
catch(const std::exception& e)
{
    cout << e.what() << endl;
    // Exit with a nonzero status on failure. Previously the handler fell off the end,
    // which makes main return 0 even though an exception was thrown.
    return 1;
}
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
Instance segmentation using the PASCAL VOC2012 dataset.
Instance segmentation sort-of combines object detection with semantic
segmentation. While each dog, for example, is detected separately,
the output is not only a bounding-box but a more accurate, per-pixel
mask.
For introductions to object detection and semantic segmentation, you
can have a look at dnn_mmod_ex.cpp and dnn_semantic_segmentation.h,
respectively.
Instructions how to run the example:
1. Download the PASCAL VOC2012 data, and untar it somewhere.
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
2. Build the dnn_instance_segmentation_train_ex example program.
3. Run:
./dnn_instance_segmentation_train_ex /path/to/VOC2012
4. Wait while the network is being trained.
5. Build the dnn_instance_segmentation_ex example program.
6. Run:
./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images
An alternative to steps 2-4 above is to download a pre-trained network
from here: http://dlib.net/files/instance_segmentation_voc2012net.dnn
It would be a good idea to become familiar with dlib's DNN tooling before reading this
example. So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
before reading this example program.
*/
#ifndef DLIB_DNn_INSTANCE_SEGMENTATION_EX_H_
#define DLIB_DNn_INSTANCE_SEGMENTATION_EX_H_
#include <dlib/dnn.h>
// ----------------------------------------------------------------------------------------
namespace {
// Segmentation will be performed using patches having this size.
// NOTE(review): this anonymous namespace appears to live in a header, so every
// translation unit that includes it gets its own copy of seg_dim; an
// `inline constexpr` variable would avoid that — confirm before changing.
constexpr int seg_dim = 227;
}
// Build a square cropping rectangle centered on the input rectangle, whose
// half-side equals the larger input dimension expanded by 50%. The input
// rectangle must not be empty.
dlib::rectangle get_cropping_rect(const dlib::rectangle& rectangle)
{
    DLIB_ASSERT(!rectangle.is_empty());

    const auto center = dlib::center(rectangle);
    const auto longest_side = std::max(rectangle.width(), rectangle.height());
    const auto half_extent = static_cast<long>(std::round(longest_side / 2.0 * 1.5)); // add +50%

    return dlib::rectangle(
        center.x() - half_extent,
        center.y() - half_extent,
        center.x() + half_extent,
        center.y() + half_extent
    );
}
// ----------------------------------------------------------------------------------------
// The object detection network.
// Adapted from dnn_mmod_train_find_cars_ex.cpp and friends.
template <long num_filters, typename SUBNET> using con5d = dlib::con<num_filters,5,5,2,2,SUBNET>;
template <long num_filters, typename SUBNET> using con5 = dlib::con<num_filters,5,5,1,1,SUBNET>;
template <typename SUBNET> using bdownsampler = dlib::relu<dlib::bn_con<con5d<128,dlib::relu<dlib::bn_con<con5d<128,dlib::relu<dlib::bn_con<con5d<32,SUBNET>>>>>>>>>;
template <typename SUBNET> using adownsampler = dlib::relu<dlib::affine<con5d<128,dlib::relu<dlib::affine<con5d<128,dlib::relu<dlib::affine<con5d<32,SUBNET>>>>>>>>>;
template <typename SUBNET> using brcon5 = dlib::relu<dlib::bn_con<con5<256,SUBNET>>>;
template <typename SUBNET> using arcon5 = dlib::relu<dlib::affine<con5<256,SUBNET>>>;
using det_bnet_type = dlib::loss_mmod<dlib::con<1,9,9,1,1,brcon5<brcon5<brcon5<bdownsampler<dlib::input_rgb_image_pyramid<dlib::pyramid_down<6>>>>>>>>;
using det_anet_type = dlib::loss_mmod<dlib::con<1,9,9,1,1,arcon5<arcon5<arcon5<adownsampler<dlib::input_rgb_image_pyramid<dlib::pyramid_down<6>>>>>>>>;
// The segmentation network.
// For the time being, this is very much copy-paste from dnn_semantic_segmentation.h, although the network is made narrower (smaller feature maps).
template <int N, template <typename> class BN, int stride, typename SUBNET>
using block = BN<dlib::con<N,3,3,1,1,dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;
template <int N, template <typename> class BN, int stride, typename SUBNET>
using blockt = BN<dlib::cont<N,3,3,1,1,dlib::relu<BN<dlib::cont<N,3,3,stride,stride,SUBNET>>>>>;
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual = dlib::add_prev1<block<N,BN,1,dlib::tag1<SUBNET>>>;
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual_down = dlib::add_prev2<dlib::avg_pool<2,2,2,2,dlib::skip1<dlib::tag2<block<N,BN,2,dlib::tag1<SUBNET>>>>>>;
template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual_up = dlib::add_prev2<dlib::cont<N,2,2,2,2,dlib::skip1<dlib::tag2<blockt<N,BN,2,dlib::tag1<SUBNET>>>>>>;
template <int N, typename SUBNET> using res = dlib::relu<residual<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares = dlib::relu<residual<block,N,dlib::affine,SUBNET>>;
template <int N, typename SUBNET> using res_down = dlib::relu<residual_down<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares_down = dlib::relu<residual_down<block,N,dlib::affine,SUBNET>>;
template <int N, typename SUBNET> using res_up = dlib::relu<residual_up<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares_up = dlib::relu<residual_up<block,N,dlib::affine,SUBNET>>;
// ----------------------------------------------------------------------------------------
template <typename SUBNET> using res16 = res<16,SUBNET>;
template <typename SUBNET> using res24 = res<24,SUBNET>;
template <typename SUBNET> using res32 = res<32,SUBNET>;
template <typename SUBNET> using res48 = res<48,SUBNET>;
template <typename SUBNET> using ares16 = ares<16,SUBNET>;
template <typename SUBNET> using ares24 = ares<24,SUBNET>;
template <typename SUBNET> using ares32 = ares<32,SUBNET>;
template <typename SUBNET> using ares48 = ares<48,SUBNET>;
template <typename SUBNET> using level1 = dlib::repeat<2,res16,res<16,SUBNET>>;
template <typename SUBNET> using level2 = dlib::repeat<2,res24,res_down<24,SUBNET>>;
template <typename SUBNET> using level3 = dlib::repeat<2,res32,res_down<32,SUBNET>>;
template <typename SUBNET> using level4 = dlib::repeat<2,res48,res_down<48,SUBNET>>;
template <typename SUBNET> using alevel1 = dlib::repeat<2,ares16,ares<16,SUBNET>>;
template <typename SUBNET> using alevel2 = dlib::repeat<2,ares24,ares_down<24,SUBNET>>;
template <typename SUBNET> using alevel3 = dlib::repeat<2,ares32,ares_down<32,SUBNET>>;
template <typename SUBNET> using alevel4 = dlib::repeat<2,ares48,ares_down<48,SUBNET>>;
template <typename SUBNET> using level1t = dlib::repeat<2,res16,res_up<16,SUBNET>>;
template <typename SUBNET> using level2t = dlib::repeat<2,res24,res_up<24,SUBNET>>;
template <typename SUBNET> using level3t = dlib::repeat<2,res32,res_up<32,SUBNET>>;
template <typename SUBNET> using level4t = dlib::repeat<2,res48,res_up<48,SUBNET>>;
template <typename SUBNET> using alevel1t = dlib::repeat<2,ares16,ares_up<16,SUBNET>>;
template <typename SUBNET> using alevel2t = dlib::repeat<2,ares24,ares_up<24,SUBNET>>;
template <typename SUBNET> using alevel3t = dlib::repeat<2,ares32,ares_up<32,SUBNET>>;
template <typename SUBNET> using alevel4t = dlib::repeat<2,ares48,ares_up<48,SUBNET>>;
// ----------------------------------------------------------------------------------------
template <
template<typename> class TAGGED,
template<typename> class PREV_RESIZED,
typename SUBNET
>
using resize_and_concat = dlib::add_layer<
dlib::concat_<TAGGED,PREV_RESIZED>,
PREV_RESIZED<dlib::resize_prev_to_tagged<TAGGED,SUBNET>>>;
template <typename SUBNET> using utag1 = dlib::add_tag_layer<2100+1,SUBNET>;
template <typename SUBNET> using utag2 = dlib::add_tag_layer<2100+2,SUBNET>;
template <typename SUBNET> using utag3 = dlib::add_tag_layer<2100+3,SUBNET>;
template <typename SUBNET> using utag4 = dlib::add_tag_layer<2100+4,SUBNET>;
template <typename SUBNET> using utag1_ = dlib::add_tag_layer<2110+1,SUBNET>;
template <typename SUBNET> using utag2_ = dlib::add_tag_layer<2110+2,SUBNET>;
template <typename SUBNET> using utag3_ = dlib::add_tag_layer<2110+3,SUBNET>;
template <typename SUBNET> using utag4_ = dlib::add_tag_layer<2110+4,SUBNET>;
template <typename SUBNET> using concat_utag1 = resize_and_concat<utag1,utag1_,SUBNET>;
template <typename SUBNET> using concat_utag2 = resize_and_concat<utag2,utag2_,SUBNET>;
template <typename SUBNET> using concat_utag3 = resize_and_concat<utag3,utag3_,SUBNET>;
template <typename SUBNET> using concat_utag4 = resize_and_concat<utag4,utag4_,SUBNET>;
// ----------------------------------------------------------------------------------------
static const char* instance_segmentation_net_filename = "instance_segmentation_voc2012net.dnn";
// ----------------------------------------------------------------------------------------
// training network type
using seg_bnet_type = dlib::loss_multiclass_log_per_pixel<
dlib::cont<2,1,1,1,1,
dlib::relu<dlib::bn_con<dlib::cont<16,7,7,2,2,
concat_utag1<level1t<
concat_utag2<level2t<
concat_utag3<level3t<
concat_utag4<level4t<
level4<utag4<
level3<utag3<
level2<utag2<
level1<dlib::max_pool<3,3,2,2,utag1<
dlib::relu<dlib::bn_con<dlib::con<16,7,7,2,2,
dlib::input<dlib::matrix<dlib::rgb_pixel>>
>>>>>>>>>>>>>>>>>>>>>>>>>;
// testing network type (replaced batch normalization with fixed affine transforms)
using seg_anet_type = dlib::loss_multiclass_log_per_pixel<
dlib::cont<2,1,1,1,1,
dlib::relu<dlib::affine<dlib::cont<16,7,7,2,2,
concat_utag1<alevel1t<
concat_utag2<alevel2t<
concat_utag3<alevel3t<
concat_utag4<alevel4t<
alevel4<utag4<
alevel3<utag3<
alevel2<utag2<
alevel1<dlib::max_pool<3,3,2,2,utag1<
dlib::relu<dlib::affine<dlib::con<16,7,7,2,2,
dlib::input<dlib::matrix<dlib::rgb_pixel>>
>>>>>>>>>>>>>>>>>>>>>>>>>;
// ----------------------------------------------------------------------------------------
#endif // DLIB_DNn_INSTANCE_SEGMENTATION_EX_H_
This diff is collapsed.
......@@ -34,83 +34,7 @@
#define DLIB_DNn_SEMANTIC_SEGMENTATION_EX_H_
#include <dlib/dnn.h>
// ----------------------------------------------------------------------------------------
// Channel-wise equality for dlib::rgb_pixel.
inline bool operator == (const dlib::rgb_pixel& a, const dlib::rgb_pixel& b)
{
    const bool same_red   = (a.red   == b.red);
    const bool same_green = (a.green == b.green);
    const bool same_blue  = (a.blue  == b.blue);
    return same_red && same_green && same_blue;
}
// ----------------------------------------------------------------------------------------
// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background. Each class
// is represented using an RGB color value. We associate each class also to an index in the
// range [0, 20], used internally by the network.
// Describes one PASCAL VOC2012 ground-truth class: its internal index, the RGB
// color used to encode it in label images, and its human-readable name.
// All members are const, so instances are immutable after construction
// (copy-constructible, but not assignable).
struct Voc2012class {
Voc2012class(uint16_t index, const dlib::rgb_pixel& rgb_label, const std::string& classlabel)
: index(index), rgb_label(rgb_label), classlabel(classlabel)
{}
// The index of the class. In the PASCAL VOC 2012 dataset, indexes from 0 to 20 are valid.
const uint16_t index = 0;
// The corresponding RGB representation of the class.
const dlib::rgb_pixel rgb_label;
// The label of the class in plain text.
const std::string classlabel;
};
// Lookup table for the VOC2012 label palette: maps each class index to the RGB
// color used in the dataset's label images and to the plain-text class name.
namespace {
constexpr int class_count = 21; // background + 20 classes
const std::vector<Voc2012class> classes = {
Voc2012class(0, dlib::rgb_pixel(0, 0, 0), ""), // background
// The cream-colored `void' label is used in border regions and to mask difficult objects
// (see http://host.robots.ox.ac.uk/pascal/VOC/voc2012/htmldoc/devkit_doc.html)
Voc2012class(dlib::loss_multiclass_log_per_pixel_::label_to_ignore,
dlib::rgb_pixel(224, 224, 192), "border"),
Voc2012class(1, dlib::rgb_pixel(128, 0, 0), "aeroplane"),
Voc2012class(2, dlib::rgb_pixel( 0, 128, 0), "bicycle"),
Voc2012class(3, dlib::rgb_pixel(128, 128, 0), "bird"),
Voc2012class(4, dlib::rgb_pixel( 0, 0, 128), "boat"),
Voc2012class(5, dlib::rgb_pixel(128, 0, 128), "bottle"),
Voc2012class(6, dlib::rgb_pixel( 0, 128, 128), "bus"),
Voc2012class(7, dlib::rgb_pixel(128, 128, 128), "car"),
Voc2012class(8, dlib::rgb_pixel( 64, 0, 0), "cat"),
Voc2012class(9, dlib::rgb_pixel(192, 0, 0), "chair"),
Voc2012class(10, dlib::rgb_pixel( 64, 128, 0), "cow"),
Voc2012class(11, dlib::rgb_pixel(192, 128, 0), "diningtable"),
Voc2012class(12, dlib::rgb_pixel( 64, 0, 128), "dog"),
Voc2012class(13, dlib::rgb_pixel(192, 0, 128), "horse"),
Voc2012class(14, dlib::rgb_pixel( 64, 128, 128), "motorbike"),
Voc2012class(15, dlib::rgb_pixel(192, 128, 128), "person"),
Voc2012class(16, dlib::rgb_pixel( 0, 64, 0), "pottedplant"),
Voc2012class(17, dlib::rgb_pixel(128, 64, 0), "sheep"),
Voc2012class(18, dlib::rgb_pixel( 0, 192, 0), "sofa"),
Voc2012class(19, dlib::rgb_pixel(128, 192, 0), "train"),
Voc2012class(20, dlib::rgb_pixel( 0, 64, 128), "tvmonitor"),
};
}
// Linearly scan the VOC2012 class table and return the first entry that
// satisfies `predicate`. Throws std::runtime_error when no entry matches.
template <typename Predicate>
const Voc2012class& find_voc2012_class(Predicate predicate)
{
    for (const auto& candidate : classes)
    {
        if (predicate(candidate))
            return candidate;
    }
    throw std::runtime_error("Unable to find a matching VOC2012 class");
}
#include "pascal_voc_2012.h"
// ----------------------------------------------------------------------------------------
......
......@@ -91,107 +91,6 @@ void randomly_crop_image (
// ----------------------------------------------------------------------------------------
// The names of the input image and the associated RGB label image in the PASCAL VOC 2012
// data set.
struct image_info
{
// Path to the JPEG input image.
// NOTE(review): unqualified `string` relies on a using-directive elsewhere in
// this file; the header version of this struct spells std::string explicitly.
string image_filename;
// Path to the PNG image holding the per-pixel RGB class labels.
string label_filename;
};
// Read the list of image files belonging to either the "train", "trainval", or "val" set
// of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_listing(
const std::string& voc2012_folder,
const std::string& file = "train" // "train", "trainval", or "val"
)
{
std::ifstream in(voc2012_folder + "/ImageSets/Segmentation/" + file + ".txt");
std::vector<image_info> results;
while (in)
{
std::string basename;
in >> basename;
if (!basename.empty())
{
image_info image_info;
image_info.image_filename = voc2012_folder + "/JPEGImages/" + basename + ".jpg";
image_info.label_filename = voc2012_folder + "/SegmentationClass/" + basename + ".png";
results.push_back(image_info);
}
}
return results;
}
// Convenience wrapper: list the images belonging to the VOC2012 "train" set.
std::vector<image_info> get_pascal_voc2012_train_listing(const std::string& voc2012_folder)
{
    return get_pascal_voc2012_listing(voc2012_folder, "train");
}
// Convenience wrapper: list the images belonging to the VOC2012 "val" set.
std::vector<image_info> get_pascal_voc2012_val_listing(const std::string& voc2012_folder)
{
    return get_pascal_voc2012_listing(voc2012_folder, "val");
}
// ----------------------------------------------------------------------------------------
// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background. Each class
// is represented using an RGB color value. We associate each class also to an index in the
// range [0, 20], used internally by the network. To convert the ground-truth data to
// something that the network can efficiently digest, we need to be able to map the RGB
// values to the corresponding indexes.
// Given an RGB representation, find the corresponding PASCAL VOC2012 class
// (e.g., 'dog'). Throws if the color is not in the class table.
const Voc2012class& find_voc2012_class(const dlib::rgb_pixel& rgb_label)
{
    const auto matches_color = [&rgb_label](const Voc2012class& candidate)
    {
        return rgb_label == candidate.rgb_label;
    };
    return find_voc2012_class(matches_color);
}
// Convert an RGB class label to an index in the range [0, 20].
inline uint16_t rgb_label_to_index_label(const dlib::rgb_pixel& rgb_label)
{
    const Voc2012class& voc_class = find_voc2012_class(rgb_label);
    return voc_class.index;
}
// Convert an image containing RGB class labels to a corresponding
// image containing indexes in the range [0, 20]. The output is resized to
// match the input; unknown colors make rgb_label_to_index_label throw.
void rgb_label_image_to_index_label_image(
    const dlib::matrix<dlib::rgb_pixel>& rgb_label_image,
    dlib::matrix<uint16_t>& index_label_image
)
{
    const long rows = rgb_label_image.nr();
    const long cols = rgb_label_image.nc();

    index_label_image.set_size(rows, cols);

    for (long row = 0; row < rows; ++row)
        for (long col = 0; col < cols; ++col)
            index_label_image(row, col)
                = rgb_label_to_index_label(rgb_label_image(row, col));
}
// ----------------------------------------------------------------------------------------
// Calculate the per-pixel accuracy on a dataset whose file names are supplied as a parameter.
double calculate_accuracy(anet_type& anet, const std::vector<image_info>& dataset)
{
......@@ -209,14 +108,14 @@ double calculate_accuracy(anet_type& anet, const std::vector<image_info>& datase
load_image(input_image, image_info.image_filename);
// Load the ground-truth (RGB) labels.
load_image(rgb_label_image, image_info.label_filename);
load_image(rgb_label_image, image_info.class_label_filename);
// Create predictions for each pixel. At this point, the type of each prediction
// is an index (a value between 0 and 20). Note that the net may return an image
// that is not exactly the same size as the input.
const matrix<uint16_t> temp = anet(input_image);
// Convert the indexes to RGB values.
// Convert the RGB values to indexes.
rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);
// Crop the net output to be exactly the same size as the input.
......@@ -324,9 +223,9 @@ int main(int argc, char** argv) try
load_image(input_image, image_info.image_filename);
// Load the ground-truth (RGB) labels.
load_image(rgb_label_image, image_info.label_filename);
load_image(rgb_label_image, image_info.class_label_filename);
// Convert the indexes to RGB values.
// Convert the RGB values to indexes.
rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);
// Randomly pick a part of the image.
......
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
Helper definitions for working with the PASCAL VOC2012 dataset.
*/
#ifndef PASCAL_VOC_2012_H_
#define PASCAL_VOC_2012_H_
#include <dlib/pixel.h>
// ----------------------------------------------------------------------------------------
// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background. Each class
// is represented using an RGB color value. We associate each class also to an index in the
// range [0, 20], used internally by the network. To convert the ground-truth data to
// something that the network can efficiently digest, we need to be able to map the RGB
// values to the corresponding indexes.
struct Voc2012class {
// Builds one class descriptor; all members are immutable after construction.
// NOTE(review): the const data members make Voc2012class non-assignable. That
// is fine for the read-only `classes` table below (copy construction still
// works), but would block use in contexts requiring assignment.
Voc2012class(uint16_t index, const dlib::rgb_pixel& rgb_label, const std::string& classlabel)
: index(index), rgb_label(rgb_label), classlabel(classlabel)
{}
// The index of the class. In the PASCAL VOC 2012 dataset, indexes from 0 to 20 are valid.
const uint16_t index = 0;
// The corresponding RGB representation of the class.
const dlib::rgb_pixel rgb_label;
// The label of the class in plain text.
const std::string classlabel;
};
namespace {
// NOTE(review): anonymous namespace in a header -- every translation unit that
// includes pascal_voc_2012.h gets its own copy of `classes`. Consider C++17
// `inline const` at namespace scope if this header is included from more than
// one .cpp; confirm how the examples are built before changing.
constexpr int class_count = 21; // background + 20 classes
// Lookup table mapping each VOC2012 RGB color to its class index and name.
// Colors follow the VOC2012 devkit color map; order is background, the
// special "border"/void label, then the 20 object classes.
const std::vector<Voc2012class> classes = {
Voc2012class(0, dlib::rgb_pixel(0, 0, 0), ""), // background
// The cream-colored `void' label is used in border regions and to mask difficult objects
// (see http://host.robots.ox.ac.uk/pascal/VOC/voc2012/htmldoc/devkit_doc.html)
Voc2012class(dlib::loss_multiclass_log_per_pixel_::label_to_ignore,
dlib::rgb_pixel(224, 224, 192), "border"),
Voc2012class(1, dlib::rgb_pixel(128, 0, 0), "aeroplane"),
Voc2012class(2, dlib::rgb_pixel( 0, 128, 0), "bicycle"),
Voc2012class(3, dlib::rgb_pixel(128, 128, 0), "bird"),
Voc2012class(4, dlib::rgb_pixel( 0, 0, 128), "boat"),
Voc2012class(5, dlib::rgb_pixel(128, 0, 128), "bottle"),
Voc2012class(6, dlib::rgb_pixel( 0, 128, 128), "bus"),
Voc2012class(7, dlib::rgb_pixel(128, 128, 128), "car"),
Voc2012class(8, dlib::rgb_pixel( 64, 0, 0), "cat"),
Voc2012class(9, dlib::rgb_pixel(192, 0, 0), "chair"),
Voc2012class(10, dlib::rgb_pixel( 64, 128, 0), "cow"),
Voc2012class(11, dlib::rgb_pixel(192, 128, 0), "diningtable"),
Voc2012class(12, dlib::rgb_pixel( 64, 0, 128), "dog"),
Voc2012class(13, dlib::rgb_pixel(192, 0, 128), "horse"),
Voc2012class(14, dlib::rgb_pixel( 64, 128, 128), "motorbike"),
Voc2012class(15, dlib::rgb_pixel(192, 128, 128), "person"),
Voc2012class(16, dlib::rgb_pixel( 0, 64, 0), "pottedplant"),
Voc2012class(17, dlib::rgb_pixel(128, 64, 0), "sheep"),
Voc2012class(18, dlib::rgb_pixel( 0, 192, 0), "sofa"),
Voc2012class(19, dlib::rgb_pixel(128, 192, 0), "train"),
Voc2012class(20, dlib::rgb_pixel( 0, 64, 128), "tvmonitor"),
};
}
// Linearly scan the VOC2012 class table and return the first entry that
// satisfies `predicate`. Throws std::runtime_error when no entry matches.
template <typename Predicate>
const Voc2012class& find_voc2012_class(Predicate predicate)
{
    for (const auto& candidate : classes)
    {
        if (predicate(candidate))
            return candidate;
    }
    throw std::runtime_error("Unable to find a matching VOC2012 class");
}
// ----------------------------------------------------------------------------------------
// The names of the input image and the associated RGB label image in the PASCAL VOC 2012
// data set.
struct image_info
{
// Path to the JPEG input image (JPEGImages folder).
std::string image_filename;
// Path to the PNG image holding per-pixel RGB class labels (SegmentationClass folder).
std::string class_label_filename;
// Path to the PNG image holding per-pixel instance labels (SegmentationObject folder).
std::string instance_label_filename;
};
// Read the list of image files belonging to either the "train", "trainval", or "val" set
// of the PASCAL VOC2012 data. Returns one image_info per basename in the
// listing file, with image, class-label, and instance-label paths filled in.
std::vector<image_info> get_pascal_voc2012_listing(
    const std::string& voc2012_folder,
    const std::string& file = "train" // "train", "trainval", or "val"
)
{
    std::ifstream in(voc2012_folder + "/ImageSets/Segmentation/" + file + ".txt");

    std::vector<image_info> results;

    // operator>> only succeeds on a non-empty whitespace-delimited token, so
    // this loop is equivalent to reading and skipping empty names.
    std::string basename;
    while (in >> basename)
    {
        image_info info;
        info.image_filename          = voc2012_folder + "/JPEGImages/"         + basename + ".jpg";
        info.class_label_filename    = voc2012_folder + "/SegmentationClass/"  + basename + ".png";
        info.instance_label_filename = voc2012_folder + "/SegmentationObject/" + basename + ".png";
        results.push_back(info);
    }

    return results;
}
// Convenience wrapper: list the images belonging to the VOC2012 "train" set.
std::vector<image_info> get_pascal_voc2012_train_listing(const std::string& voc2012_folder)
{
    return get_pascal_voc2012_listing(voc2012_folder, "train");
}
// Convenience wrapper: list the images belonging to the VOC2012 "val" set.
std::vector<image_info> get_pascal_voc2012_val_listing(const std::string& voc2012_folder)
{
    return get_pascal_voc2012_listing(voc2012_folder, "val");
}
// Given an RGB representation, find the corresponding PASCAL VOC2012 class
// (e.g., 'dog'). Throws if the color is not in the class table.
const Voc2012class& find_voc2012_class(const dlib::rgb_pixel& rgb_label)
{
    const auto matches_color = [&rgb_label](const Voc2012class& candidate)
    {
        return rgb_label == candidate.rgb_label;
    };
    return find_voc2012_class(matches_color);
}
// ----------------------------------------------------------------------------------------
// Convert an RGB class label to an index in the range [0, 20].
inline uint16_t rgb_label_to_index_label(const dlib::rgb_pixel& rgb_label)
{
    const Voc2012class& voc_class = find_voc2012_class(rgb_label);
    return voc_class.index;
}
// Convert an image containing RGB class labels to a corresponding
// image containing indexes in the range [0, 20]. The output matrix is resized
// to match the input; unknown colors make rgb_label_to_index_label throw.
void rgb_label_image_to_index_label_image(
    const dlib::matrix<dlib::rgb_pixel>& rgb_label_image,
    dlib::matrix<uint16_t>& index_label_image
)
{
    const long rows = rgb_label_image.nr();
    const long cols = rgb_label_image.nc();

    index_label_image.set_size(rows, cols);

    for (long row = 0; row < rows; ++row)
        for (long col = 0; col < cols; ++col)
            index_label_image(row, col)
                = rgb_label_to_index_label(rgb_label_image(row, col));
}
#endif // PASCAL_VOC_2012_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment