"tests/python/vscode:/vscode.git/clone" did not exist on "405de769b7faabb19b00a176e0b6dca8a0df3581"
Unverified commit 16500906 authored by Adrià Arrufat, committed by GitHub

YOLO loss (#2376)

parent 951fdd00
...@@ -60,7 +60,13 @@ namespace dlib
fout << "<images>\n";
for (unsigned long i = 0; i < images.size(); ++i)
{
fout << " <image file='" << images[i].filename << "'";
if (images[i].width != 0 && images[i].height != 0)
{
fout << " width='" << images[i].width << "'";
fout << " height='" << images[i].height << "'";
}
fout << ">\n";
// save all the boxes
for (unsigned long j = 0; j < images[i].boxes.size(); ++j)
...@@ -251,6 +257,9 @@ namespace dlib
if (atts.is_in_list("file")) temp_image.filename = atts["file"];
else throw dlib::error("<image> missing required attribute 'file'");
if (atts.is_in_list("width")) temp_image.width = sa = atts["width"];
if (atts.is_in_list("height")) temp_image.height = sa = atts["height"];
}
ts.push_back(name);
...
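
For illustration, a small sketch (not part of the patch) showing how the new optional
width/height attributes round-trip through the existing image_dataset_metadata API;
the file names are hypothetical:

#include <dlib/data_io.h>

int main()
{
    dlib::image_dataset_metadata::dataset data;
    dlib::image_dataset_metadata::image img("example.jpg");
    img.width = 640;   // the new optional attributes added by this patch
    img.height = 480;
    img.boxes.emplace_back(dlib::rectangle(10, 10, 100, 100));
    data.images.push_back(img);
    // writes: <image file='example.jpg' width='640' height='480'>
    dlib::image_dataset_metadata::save_image_dataset_metadata(data, "training.xml");
}
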
...@@ -101,7 +101,7 @@ namespace dlib
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents an annotated image.
!*/
image() {}
...@@ -109,6 +109,8 @@ namespace dlib
std::string filename;
std::vector<box> boxes;
long width = 0;
long height = 0;
};
// ------------------------------------------------------------------------------------
...
...@@ -3447,6 +3447,529 @@ namespace dlib
// ----------------------------------------------------------------------------------------
struct yolo_options
{
public:
struct anchor_box_details
{
anchor_box_details() = default;
anchor_box_details(unsigned long w, unsigned long h) : width(w), height(h) {}
unsigned long width = 0;
unsigned long height = 0;
friend inline void serialize(const anchor_box_details& item, std::ostream& out)
{
int version = 0;
serialize(version, out);
serialize(item.width, out);
serialize(item.height, out);
}
friend inline void deserialize(anchor_box_details& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
deserialize(item.width, in);
deserialize(item.height, in);
}
};
yolo_options() = default;
template <template <typename> class TAG_TYPE>
void add_anchors(const std::vector<anchor_box_details>& boxes)
{
anchors[tag_id<TAG_TYPE>::id] = boxes;
}
// map between the stride and the anchor boxes
std::map<int, std::vector<anchor_box_details>> anchors;
std::vector<std::string> labels;
double iou_ignore_threshold = 0.7;
double iou_anchor_threshold = 1.0;
test_box_overlap overlaps_nms = test_box_overlap(0.45, 1.0);
bool classwise_nms = true;
double lambda_obj = 1.0;
double lambda_box = 1.0;
double lambda_cls = 1.0;
};
inline void serialize(const yolo_options& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.anchors, out);
serialize(item.labels, out);
serialize(item.iou_ignore_threshold, out);
serialize(item.iou_anchor_threshold, out);
serialize(item.classwise_nms, out);
serialize(item.overlaps_nms, out);
serialize(item.lambda_obj, out);
serialize(item.lambda_box, out);
serialize(item.lambda_cls, out);
}
inline void deserialize(yolo_options& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::yolo_options.");
deserialize(item.anchors, in);
deserialize(item.labels, in);
deserialize(item.iou_ignore_threshold, in);
deserialize(item.iou_anchor_threshold, in);
deserialize(item.classwise_nms, in);
deserialize(item.overlaps_nms, in);
deserialize(item.lambda_obj, in);
deserialize(item.lambda_box, in);
deserialize(item.lambda_cls, in);
}
inline std::ostream& operator<<(std::ostream& out, const std::map<int, std::vector<yolo_options::anchor_box_details>>& anchors)
{
// write anchor boxes grouped by tag id
size_t tag_count = 0;
for (const auto& i : anchors)
{
const auto& tag_id = i.first;
const auto& details = i.second;
if (tag_count++ > 0)
out << ";";
out << "tag" << tag_id << ":";
for (size_t a = 0; a < details.size(); ++a)
{
out << details[a].width << "x" << details[a].height;
if (a + 1 < details.size())
out << ",";
}
}
return out;
}
namespace impl
{
template <template <typename> class TAG_TYPE, template <typename> class... TAG_TYPES>
struct yolo_helper_impl
{
constexpr static size_t tag_count()
{
return 1 + yolo_helper_impl<TAG_TYPES...>::tag_count();
}
static void list_tags(std::ostream& out)
{
out << "tag" << tag_id<TAG_TYPE>::id << (tag_count() > 1 ? "," : "");
yolo_helper_impl<TAG_TYPES...>::list_tags(out);
}
template <typename SUBNET>
static void tensor_to_dets (
const tensor& input_tensor,
const SUBNET& sub,
const long n,
const yolo_options& options,
const double adjust_threshold,
std::vector<yolo_rect>& dets
)
{
yolo_helper_impl<TAG_TYPE>::tensor_to_dets(input_tensor, sub, n, options, adjust_threshold, dets);
yolo_helper_impl<TAG_TYPES...>::tensor_to_dets(input_tensor, sub, n, options, adjust_threshold, dets);
}
template <
typename const_label_iterator,
typename SUBNET
>
static void tensor_to_loss (
const tensor& input_tensor,
const_label_iterator truth,
SUBNET& sub,
const long n,
const yolo_options& options,
double& loss
)
{
yolo_helper_impl<TAG_TYPE>::tensor_to_loss(input_tensor, truth, sub, n, options, loss);
yolo_helper_impl<TAG_TYPES...>::tensor_to_loss(input_tensor, truth, sub, n, options, loss);
}
};
template <template <typename> class TAG_TYPE>
struct yolo_helper_impl<TAG_TYPE>
{
constexpr static size_t tag_count() { return 1; }
static void list_tags(std::ostream& out) { out << "tag" << tag_id<TAG_TYPE>::id; }
template <typename SUBNET>
static void tensor_to_dets (
const tensor& input_tensor,
const SUBNET& sub,
const long n,
const yolo_options& options,
const double adjust_threshold,
std::vector<yolo_rect>& dets
)
{
DLIB_CASSERT(sub.sample_expansion_factor() == 1, sub.sample_expansion_factor());
const auto& anchors = options.anchors.at(tag_id<TAG_TYPE>::id);
const tensor& output_tensor = layer<TAG_TYPE>(sub).get_output();
DLIB_CASSERT(static_cast<size_t>(output_tensor.k()) == anchors.size() * (options.labels.size() + 5));
const auto stride_x = static_cast<double>(input_tensor.nc()) / output_tensor.nc();
const auto stride_y = static_cast<double>(input_tensor.nr()) / output_tensor.nr();
const long num_feats = output_tensor.k() / anchors.size();
const long num_classes = num_feats - 5;
const float* const out_data = output_tensor.host();
for (size_t a = 0; a < anchors.size(); ++a)
{
for (long r = 0; r < output_tensor.nr(); ++r)
{
for (long c = 0; c < output_tensor.nc(); ++c)
{
const float obj = out_data[tensor_index(output_tensor, n, a * num_feats + 4, r, c)];
if (obj > adjust_threshold)
{
const auto x = out_data[tensor_index(output_tensor, n, a * num_feats + 0, r, c)];
const auto y = out_data[tensor_index(output_tensor, n, a * num_feats + 1, r, c)];
const auto w = out_data[tensor_index(output_tensor, n, a * num_feats + 2, r, c)];
const auto h = out_data[tensor_index(output_tensor, n, a * num_feats + 3, r, c)];
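// Since the outputs pass through a sigmoid, w and h lie in (0, 1), and
// w / (1 - w) == exp(t_w) for the raw network output t_w. The decoding
// below is therefore YOLOv3's b_w = p_w * e^(t_w), with the anchor
// dimensions playing the role of the priors p_w and p_h.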
yolo_rect det(centered_drect(dpoint((x + c) * stride_x, (y + r) * stride_y),
w / (1 - w) * anchors[a].width,
h / (1 - h) * anchors[a].height));
for (long k = 0; k < num_classes; ++k)
{
const float conf = out_data[tensor_index(output_tensor, n, a * num_feats + 5 + k, r, c)] * obj;
if (conf > adjust_threshold)
det.labels.emplace_back(conf, options.labels[k]);
}
if (!det.labels.empty())
{
std::sort(det.labels.rbegin(), det.labels.rend());
det.detection_confidence = det.labels[0].first;
det.label = det.labels[0].second;
dets.push_back(std::move(det));
}
}
}
}
}
}
template <
typename const_label_iterator,
typename SUBNET
>
static void tensor_to_loss (
const tensor& input_tensor,
const_label_iterator truth,
SUBNET& sub,
const long n,
const yolo_options& options,
double& loss
)
{
DLIB_CASSERT(sub.sample_expansion_factor() == 1, sub.sample_expansion_factor());
const tensor& output_tensor = layer<TAG_TYPE>(sub).get_output();
const auto& anchors = options.anchors.at(tag_id<TAG_TYPE>::id);
DLIB_CASSERT(static_cast<size_t>(output_tensor.k()) == anchors.size() * (options.labels.size() + 5));
const auto stride_x = static_cast<double>(input_tensor.nc()) / output_tensor.nc();
const auto stride_y = static_cast<double>(input_tensor.nr()) / output_tensor.nr();
const long num_feats = output_tensor.k() / anchors.size();
const long num_classes = num_feats - 5;
const float* const out_data = output_tensor.host();
tensor& grad = layer<TAG_TYPE>(sub).get_gradient_input();
DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
float* g = grad.host();
// Compute the objectness loss for all grid cells
for (long r = 0; r < output_tensor.nr(); ++r)
{
for (long c = 0; c < output_tensor.nc(); ++c)
{
for (size_t a = 0; a < anchors.size(); ++a)
{
const auto x_idx = tensor_index(output_tensor, n, a * num_feats + 0, r, c);
const auto y_idx = tensor_index(output_tensor, n, a * num_feats + 1, r, c);
const auto w_idx = tensor_index(output_tensor, n, a * num_feats + 2, r, c);
const auto h_idx = tensor_index(output_tensor, n, a * num_feats + 3, r, c);
const auto o_idx = tensor_index(output_tensor, n, a * num_feats + 4, r, c);
// The prediction at r, c for anchor a
const yolo_rect pred(centered_drect(
dpoint((out_data[x_idx] + c) * stride_x, (out_data[y_idx] + r) * stride_y),
out_data[w_idx] / (1 - out_data[w_idx]) * anchors[a].width,
out_data[h_idx] / (1 - out_data[h_idx]) * anchors[a].height
));
// Find the best IoU for all ground truth boxes
double best_iou = 0;
for (const yolo_rect& truth_box : *truth)
{
if (truth_box.ignore)
continue;
best_iou = std::max(best_iou, box_intersection_over_union(truth_box.rect, pred.rect));
}
// Incur loss for the boxes that are below a certain IoU threshold with any truth box
if (best_iou < options.iou_ignore_threshold)
g[o_idx] = out_data[o_idx];
}
}
}
// Now find the best anchor box for each truth box
for (const yolo_rect& truth_box : *truth)
{
if (truth_box.ignore)
continue;
const dpoint t_center = dcenter(truth_box);
double best_iou = 0;
size_t best_a = 0;
size_t best_tag_id = 0;
for (const auto& item : options.anchors)
{
const auto tag_id = item.first;
const auto details = item.second;
for (size_t a = 0; a < details.size(); ++a)
{
const yolo_rect anchor(centered_drect(t_center, details[a].width, details[a].height));
const double iou = box_intersection_over_union(truth_box.rect, anchor.rect);
if (iou > best_iou)
{
best_iou = iou;
best_a = a;
best_tag_id = tag_id;
}
}
}
for (size_t a = 0; a < anchors.size(); ++a)
{
// Update best anchor if it's from the current stride, and optionally other anchors
if ((best_tag_id == tag_id<TAG_TYPE>::id && best_a == a) || options.iou_anchor_threshold < 1)
{
// do not update other anchors if they have low IoU
if (!(best_tag_id == tag_id<TAG_TYPE>::id && best_a == a))
{
const yolo_rect anchor(centered_drect(t_center, anchors[a].width, anchors[a].height));
const double iou = box_intersection_over_union(truth_box.rect, anchor.rect);
if (iou < options.iou_anchor_threshold)
continue;
}
const long c = t_center.x() / stride_x;
const long r = t_center.y() / stride_y;
const auto x_idx = tensor_index(output_tensor, n, a * num_feats + 0, r, c);
const auto y_idx = tensor_index(output_tensor, n, a * num_feats + 1, r, c);
const auto w_idx = tensor_index(output_tensor, n, a * num_feats + 2, r, c);
const auto h_idx = tensor_index(output_tensor, n, a * num_feats + 3, r, c);
const auto o_idx = tensor_index(output_tensor, n, a * num_feats + 4, r, c);
// This grid cell should detect an object
g[o_idx] = options.lambda_obj * (out_data[o_idx] - 1);
// Get the truth box target values
const double tx = t_center.x() / stride_x - c;
const double ty = t_center.y() / stride_y - r;
const double tw = truth_box.rect.width() / (anchors[a].width + truth_box.rect.width());
const double th = truth_box.rect.height() / (anchors[a].height + truth_box.rect.height());
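// Note that these targets invert the decoding used in tensor_to_dets: solving
// w / (1 - w) * anchor_width == truth_width for the sigmoid output w gives
// w == truth_width / (anchor_width + truth_width), which is exactly tw.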
// Scale regression error according to the truth size
const double scale_box = 2 - truth_box.rect.area() / (input_tensor.nr() * input_tensor.nc());
// Compute the gradient for the box coordinates
g[x_idx] = options.lambda_box * scale_box * (out_data[x_idx] - tx);
g[y_idx] = options.lambda_box * scale_box * (out_data[y_idx] - ty);
g[w_idx] = options.lambda_box * scale_box * (out_data[w_idx] - tw);
g[h_idx] = options.lambda_box * scale_box * (out_data[h_idx] - th);
// Compute the classification error
for (long k = 0; k < num_classes; ++k)
{
const auto c_idx = tensor_index(output_tensor, n, a * num_feats + 5 + k, r, c);
if (truth_box.label == options.labels[k])
g[c_idx] = options.lambda_cls * (out_data[c_idx] - 1);
else
g[c_idx] = options.lambda_cls * out_data[c_idx];
}
}
}
}
// Compute the L2 loss
loss += length_squared(rowm(mat(grad), n));
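// (i.e. the value reported is literally the squared norm of this sample's
// gradient, matching the note in loss_abstract.h that the loss corresponds
// to the squared norm of the error gradient)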
}
};
}
template <template <typename> class... TAG_TYPES>
class loss_yolo_
{
static void list_tags(std::ostream& out) { impl::yolo_helper_impl<TAG_TYPES...>::list_tags(out); }
public:
typedef std::vector<yolo_rect> training_label_type;
typedef std::vector<yolo_rect> output_label_type;
constexpr static size_t tag_count() { return impl::yolo_helper_impl<TAG_TYPES...>::tag_count(); }
loss_yolo_() {};
loss_yolo_(const yolo_options& options) : options(options) { }
template <
typename SUB_TYPE,
typename label_iterator
>
void to_label (
const tensor& input_tensor,
const SUB_TYPE& sub,
label_iterator iter,
double adjust_threshold = 0.25
) const
{
std::vector<yolo_rect> dets_accum;
std::vector<yolo_rect> final_dets;
for (long i = 0; i < input_tensor.num_samples(); ++i)
{
dets_accum.clear();
impl::yolo_helper_impl<TAG_TYPES...>::tensor_to_dets(input_tensor, sub, i, options, adjust_threshold, dets_accum);
// Do non-max suppression
std::sort(dets_accum.rbegin(), dets_accum.rend());
final_dets.clear();
for (size_t j = 0; j < dets_accum.size(); ++j)
{
if (overlaps_any_box_nms(final_dets, dets_accum[j]))
continue;
final_dets.push_back(dets_accum[j]);
}
*iter++ = std::move(final_dets);
}
}
template <
typename const_label_iterator,
typename SUBNET
>
double compute_loss_value_and_gradient (
const tensor& input_tensor,
const_label_iterator truth,
SUBNET& sub
) const
{
DLIB_CASSERT(input_tensor.num_samples() != 0);
double loss = 0;
for (long i = 0; i < input_tensor.num_samples(); ++i)
{
impl::yolo_helper_impl<TAG_TYPES...>::tensor_to_loss(input_tensor, truth, sub, i, options, loss);
++truth;
}
return loss / input_tensor.num_samples();
}
const yolo_options& get_options() const { return options; }
void adjust_nms(double iou_thresh, double percent_covered_thresh = 1, bool classwise = true)
{
options.overlaps_nms = test_box_overlap(iou_thresh, percent_covered_thresh);
options.classwise_nms = classwise;
}
friend void serialize(const loss_yolo_& item, std::ostream& out)
{
serialize("loss_yolo_", out);
size_t count = tag_count();
serialize(count, out);
serialize(item.options, out);
}
friend void deserialize(loss_yolo_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "loss_yolo_")
throw serialization_error("Unexpected version found while deserializing dlib::loss_yolo_.");
size_t count = 0;
deserialize(count, in);
if (count != tag_count())
throw serialization_error("Invalid number of detection tags " + std::to_string(count) +
", while deserializing dlib::loss_yolo_, expecting " +
std::to_string(tag_count()) + " tags instead.");
deserialize(item.options, in);
}
friend std::ostream& operator<<(std::ostream& out, const loss_yolo_& item)
{
out << "loss_yolo\t (";
const auto& opts = item.options;
out << tag_count() << " output" << (tag_count() != 1 ? "s" : "") << ":(";
list_tags(out);
out << ")";
out << ", anchor_boxes:(" << opts.anchors << ")";
out << ", " << opts.labels.size() << " label" << (opts.labels.size() != 1 ? "s" : "") << ":(";
for (size_t i = 0; i < opts.labels.size(); ++i)
{
out << opts.labels[i];
if (i + 1 < opts.labels.size())
out << ",";
}
out << ")";
out << ", iou_ignore_threshold: " << opts.iou_ignore_threshold;
out << ", iou_anchor_threshold: " << opts.iou_anchor_threshold;
out << ", lambda_obj:" << opts.lambda_obj;
out << ", lambda_box:" << opts.lambda_box;
out << ", lambda_cls:" << opts.lambda_cls;
out << ", overlaps_nms:(" << opts.overlaps_nms.get_iou_thresh() << "," << opts.overlaps_nms.get_percent_covered_thresh() << ")";
out << ", classwise_nms:" << std::boolalpha << opts.classwise_nms;
out << ")";
return out;
}
friend void to_xml(const loss_yolo_& /*item*/, std::ostream& out)
{
out << "<loss_yolo/>";
}
private:
yolo_options options;
inline bool overlaps_any_box_nms (
const std::vector<yolo_rect>& boxes,
const yolo_rect& box
) const
{
for (const auto& b : boxes)
{
if (options.overlaps_nms(b.rect, box.rect))
{
if (options.classwise_nms)
{
if (b.label == box.label)
return true;
}
else
{
return true;
}
}
}
return false;
}
};
template <template <typename> class TAG_1, template <typename> class TAG_2, template <typename> class TAG_3, typename SUBNET>
using loss_yolo = add_loss_layer<loss_yolo_<TAG_1, TAG_2, TAG_3>, SUBNET>;
}
#endif // DLIB_DNn_LOSS_H_
...
...@@ -1852,9 +1852,192 @@ namespace dlib
template <typename SUBNET>
using loss_dot = add_loss_layer<loss_dot_, SUBNET>;
// ----------------------------------------------------------------------------------------
struct yolo_options
{
/*!
WHAT THIS OBJECT REPRESENTS
This object contains all the parameters that control the behavior of loss_yolo_.
!*/
public:
struct anchor_box_details
{
anchor_box_details() = default;
anchor_box_details(unsigned long w, unsigned long h) : width(w), height(h) {}
unsigned long width = 0;
unsigned long height = 0;
friend inline void serialize(const anchor_box_details& item, std::ostream& out);
friend inline void deserialize(anchor_box_details& item, std::istream& in);
};
yolo_options() = default;
// This kind of object detector is a multi-scale object detector with bounding box
// regression for anchor boxes. The anchors field determines which anchors will be
used at the output pointed to by the tag layer whose id is the key of the map.
std::map<int, std::vector<anchor_box_details>> anchors;
template <template <typename> class TAG_TYPE>
void add_anchors(
const std::vector<anchor_box_details>& boxes
);
/*!
ensures
- anchors.at(tag_id<TAG_TYPE>::id) == boxes
!*/
// This field contains the labels of all the possible objects this detector can find.
std::vector<std::string> labels;
// When computing the objectness loss, any detection that has an IoU above
// iou_ignore_threshold with a ground truth box will not incur any loss.
double iou_ignore_threshold = 0.7;
// When computing the YOLO loss (objectness + bounding box regression + classification),
// the best match between a truth and an anchor is always used, regardless of the IoU.
// However, if other anchors have an IoU with a truth box above iou_anchor_threshold, they
will also experience loss against that truth box. Setting iou_anchor_threshold to 1 will
// make the model use only the best anchor for each ground truth, so other anchors can be
// used for other ground truth boxes in the same cell (useful for detecting objects in crowds).
// This setting is meant to be used with "high capacity" models, not small ones.
double iou_anchor_threshold = 1.0;
// When doing non-max suppression, we use overlaps_nms to decide if a box overlaps
// an already output detection and should therefore be thrown out.
test_box_overlap overlaps_nms = test_box_overlap(0.45, 1.0);
// When set to true, NMS will only be applied between objects with the same class label.
bool classwise_nms = true;
// These parameters control how we penalize different kinds of mistakes: notably the objectness loss,
// the box (bounding box regression) loss, and the classification loss.
double lambda_obj = 1.0;
double lambda_box = 1.0;
double lambda_cls = 1.0;
};
void serialize(const yolo_options& item, std::ostream& out);
void deserialize(yolo_options& item, std::istream& in);
/*!
provides serialization support
!*/
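
For reference, a minimal sketch of filling in these options. It assumes tag layers
named ytag8, ytag16 and ytag32 like the ones defined in the dnn_yolo_train_ex.cpp
example below; the labels and anchor shapes are only illustrative:

    yolo_options options;
    options.labels = {"car", "person"};
    options.iou_ignore_threshold = 0.7;
    // one set of anchors per detection output, keyed internally by the tag layer id
    options.add_anchors<ytag8>({{10, 13}, {16, 30}, {33, 23}});
    options.add_anchors<ytag16>({{30, 61}, {62, 45}, {59, 119}});
    options.add_anchors<ytag32>({{116, 90}, {156, 198}, {373, 326}});
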
// ----------------------------------------------------------------------------------------
template <template <typename> class... TAG_TYPES>
class loss_yolo_
{
/*!
WHAT THIS OBJECT REPRESENTS
This object implements the loss layer interface defined above by
EXAMPLE_LOSS_LAYER_. In particular, it implements the YOLO detection
loss defined in the paper:
YOLOv3: An Incremental Improvement by Joseph Redmon and Ali Farhadi.
This means you use this loss if you want to detect the locations of objects
in images.
It should also be noted that this loss layer requires tag layers as template
parameters, which in turn require a subnetwork to be of type:
layer<TAG_TYPE>(net).subnet(): sig<con<(num_classes + 5) * num_anchors, SUBNET>>
Where num_classes is the number of categories that the detector is trained on,
and num_anchors is the number of priors or anchor boxes at the output pointed
by the tag layer. The number 5 corresponds to the objectness plus the 4 coordinates
for performing bounding box regression.
!*/
public:
typedef std::vector<yolo_rect> training_label_type;
typedef std::vector<yolo_rect> output_label_type;
loss_yolo_(
);
/*!
ensures
- #get_options() == yolo_options()
!*/
loss_yolo_(
yolo_options options_
);
/*!
ensures
- #get_options() == options_
!*/
const yolo_options& get_options (
) const;
/*!
ensures
- returns the options object that defines the general behavior of this loss layer.
!*/
template <
typename SUB_TYPE,
typename label_iterator
>
void to_label (
const tensor& input_tensor,
const SUB_TYPE& sub,
label_iterator iter,
double adjust_threshold = 0.25
) const;
/*!
This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
it has the additional calling requirements that:
- layer<TAG_TYPE>(sub).get_output().k() == options.anchors.at(tag_id<TAG_TYPE>::id).size() * (5 + options.labels.size());
- sub.get_output().num_samples() == input_tensor.num_samples()
- sub.sample_expansion_factor() == 1
Also, the output labels are std::vectors of yolo_rects where, for each yolo_rect R,
we have the following interpretations:
- R.rect == the location of an object in the image.
- R.detection_confidence == the score for the object, between 0 and 1. Only
objects with a detection_confidence > adjust_threshold are output. So if
you want to output more objects (that are also of less confidence) you
can call to_label() with a smaller value of adjust_threshold.
- R.label == the label of the detected object.
- R.labels == a std::vector<std::pair<double, std::string>> containing all the confidence values
and labels that have a detection score > adjust_threshold, since this loss allows
for multi-label outputs. Note that the following is true:
- R.labels[0].first == R.detection_confidence
- R.labels[0].second == R.label
- R.ignore == false (this value is unused by to_label()).
!*/
template <
typename const_label_iterator,
typename SUBNET
>
double compute_loss_value_and_gradient (
const tensor& input_tensor,
const_label_iterator truth,
SUBNET& sub
) const;
/*!
This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
except it has the additional calling requirements that:
- layer<TAG_TYPE>(sub).get_output().k() == options.anchors.at(tag_id<TAG_TYPE>::id).size() * (5 + options.labels.size());
- sub.get_output().num_samples() == input_tensor.num_samples()
- sub.sample_expansion_factor() == 1
Also, the loss value returned corresponds to the squared norm of the error gradient.
!*/
void adjust_nms (
double iou_thresh,
double percent_covered_thresh = 1,
bool classwise = true
);
/*!
ensures
- #get_options().overlaps_nms == test_box_overlap(iou_thresh, percent_covered_thresh)
- #get_options().classwise_nms == classwise
!*/
};
template <template <typename> class TAG_1, template <typename> class TAG_2, template <typename> class TAG_3, typename SUBNET>
using loss_yolo = add_loss_layer<loss_yolo_<TAG_1, TAG_2, TAG_3>, SUBNET>;
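
A minimal end-to-end usage sketch (assuming a network type like the
darknet::yolov3_infer_type defined in the dnn_yolo_train_ex.cpp example below, and a
hypothetical model file):

    darknet::yolov3_infer_type net;
    deserialize("yolov3.dnn") >> net;  // hypothetical file name
    matrix<rgb_pixel> image;
    load_image(image, "test.jpg");
    // process() runs a forward pass and then calls loss_yolo_::to_label() with
    // the given confidence threshold.
    const std::vector<yolo_rect> detections = net.process(image, 0.25);
    for (const auto& det : detections)
        cout << det.label << " (" << det.detection_confidence << "): " << det.rect << endl;
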
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_LOSS_ABSTRACT_H_
...@@ -183,6 +183,60 @@ namespace dlib
item.label = "";
}
// ----------------------------------------------------------------------------------------
struct yolo_rect
{
yolo_rect() = default;
yolo_rect(const drectangle& r) : rect(r) {}
yolo_rect(const drectangle& r, double score) : rect(r),detection_confidence(score) {}
yolo_rect(const drectangle& r, double score, const std::string& label) : rect(r),detection_confidence(score), label(label) {}
yolo_rect(const mmod_rect& r) : rect(r.rect), detection_confidence(r.detection_confidence), ignore(r.ignore), label(r.label) {}
drectangle rect;
double detection_confidence = 0;
bool ignore = false;
std::string label;
std::vector<std::pair<double, std::string>> labels;
operator rectangle() const { return rect; }
bool operator == (const yolo_rect& rhs) const
{
return rect == rhs.rect
&& detection_confidence == rhs.detection_confidence
&& ignore == rhs.ignore
&& label == rhs.label;
}
bool operator<(const yolo_rect& rhs) const
{
return detection_confidence < rhs.detection_confidence;
}
};
inline void serialize(const yolo_rect& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.rect, out);
serialize(item.detection_confidence, out);
serialize(item.ignore, out);
serialize(item.label, out);
serialize(item.labels, out);
}
inline void deserialize(yolo_rect& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::yolo_rect");
deserialize(item.rect, in);
deserialize(item.detection_confidence, in);
deserialize(item.ignore, in);
deserialize(item.label, in);
deserialize(item.labels, in);
}
// ----------------------------------------------------------------------------------------
}
...
...@@ -194,10 +194,53 @@ namespace dlib
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
struct yolo_rect
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a simple struct used to give training data to, and receive
detections from, the YOLO detection loss layer loss_yolo_.
!*/
yolo_rect() = default;
yolo_rect(const drectangle& r) : rect(r) {}
yolo_rect(const drectangle& r, double score) : rect(r),detection_confidence(score) {}
yolo_rect(const drectangle& r, double score, const std::string& label) : rect(r),detection_confidence(score), label(label) {}
yolo_rect(const mmod_rect& r) : rect(r.rect), detection_confidence(r.detection_confidence), ignore(r.ignore), label(r.label) {}
drectangle rect;
double detection_confidence = 0;
bool ignore = false;
std::string label;
// YOLO detectors are multi label detectors: this field will contain all confidences and labels for a particular detection
std::vector<std::pair<double, std::string>> labels;
operator rectangle() const { return rect; }
bool operator== (const yolo_rect& rhs) const;
/*!
ensures
- returns true if and only if rect == rhs.rect && detection_confidence == rhs.detection_confidence && ignore == rhs.ignore && label == rhs.label.
!*/
bool operator<(const yolo_rect& rhs) const;
/*!
ensures
- returns true if and only if detection_confidence < rhs.detection_confidence.
!*/
};
inline void serialize(const yolo_rect& item, std::ostream& out);
inline void deserialize(yolo_rect& item, std::istream& in);
/*!
provides serialization support
!*/
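
Since yolo_rect is constructible from mmod_rect, existing mmod-style training data
converts directly; a small sketch (the mmod_boxes variable is hypothetical):

    std::vector<std::vector<mmod_rect>> mmod_boxes;  // e.g. loaded with load_image_dataset()
    std::vector<std::vector<yolo_rect>> yolo_boxes;
    for (const auto& boxes : mmod_boxes)
        yolo_boxes.emplace_back(boxes.begin(), boxes.end());  // uses yolo_rect(const mmod_rect&)
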
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_FULL_OBJECT_DeTECTION_ABSTRACT_Hh_
...@@ -23,6 +23,7 @@ namespace dlib
double max_object_size = 0.7; // cropped object will be at most this fraction of the size of the image.
double background_crops_fraction = 0.5;
double translate_amount = 0.10;
double min_object_coverage = 1.0;
std::mutex rnd_mutex;
dlib::rand rnd;
...@@ -104,15 +105,26 @@ namespace dlib
max_object_size = value;
}
double get_min_object_coverage (
) const { return min_object_coverage; }
void set_min_object_coverage (
double value
)
{
DLIB_CASSERT(0 < value && value <= 1);
min_object_coverage = value;
}
template <
typename array_type,
typename rectangle_type
>
void operator() (
size_t num_crops,
const array_type& images,
const std::vector<std::vector<rectangle_type>>& rects,
array_type& crops,
std::vector<std::vector<rectangle_type>>& crop_rects
)
{
DLIB_CASSERT(images.size() == rects.size());
...@@ -122,14 +134,15 @@ namespace dlib
}
template <
typename array_type,
typename rectangle_type
>
void append (
size_t num_crops,
const array_type& images,
const std::vector<std::vector<rectangle_type>>& rects,
array_type& crops,
std::vector<std::vector<rectangle_type>>& crop_rects
)
{
DLIB_CASSERT(images.size() == rects.size());
...@@ -145,13 +158,14 @@ namespace dlib
template <
typename array_type,
typename image_type,
typename rectangle_type
>
void operator() (
const array_type& images,
const std::vector<std::vector<rectangle_type>>& rects,
image_type& crop,
std::vector<rectangle_type>& crop_rects
)
{
DLIB_CASSERT(images.size() == rects.size());
...@@ -163,27 +177,29 @@ namespace dlib
}
template <
typename image_type1,
typename rectangle_type
>
image_type1 operator() (
const image_type1& img
)
{
image_type1 crop;
std::vector<rectangle_type> junk1, junk2;
(*this)(img, junk1, crop, junk2);
return crop;
}
template <
typename image_type1,
typename image_type2,
typename rectangle_type
>
void operator() (
const image_type1& img,
const std::vector<rectangle_type>& rects,
image_type2& crop,
std::vector<rectangle_type>& crop_rects
)
{
DLIB_CASSERT(num_rows(img)*num_columns(img) != 0);
...@@ -202,12 +218,14 @@ namespace dlib
// map to crop
rect.rect = tform(rect.rect);
const double intersection = get_rect(crop).intersect(rect.rect).area();
// if the rect is at least partly in the crop
if (intersection != 0)
{
// set to ignore if not sufficiently covered by the crop or if too small.
if (intersection / rect.rect.area() < min_object_coverage ||
((long)rect.rect.height() < min_object_length_long_dim && (long)rect.rect.width() < min_object_length_long_dim) ||
((long)rect.rect.height() < min_object_length_short_dim || (long)rect.rect.width() < min_object_length_short_dim))
{
rect.ignore = true;
...@@ -230,10 +248,13 @@ namespace dlib
private:
template <
typename image_type1,
typename rectangle_type
>
void make_crop_plan (
const image_type1& img,
const std::vector<rectangle_type>& rects,
chip_details& crop_plan,
bool& should_flip_crop
)
...@@ -285,8 +306,9 @@ namespace dlib
crop_plan = chip_details(crop_rect, dims, angle);
}
template <typename rectangle_type>
bool has_non_ignored_box (
const std::vector<rectangle_type>& rects
) const
{
for (auto&& b : rects)
...@@ -297,8 +319,9 @@ namespace dlib
return false;
}
template <typename rectangle_type>
size_t randomly_pick_rect (
const std::vector<rectangle_type>& rects
)
{
DLIB_CASSERT(has_non_ignored_box(rects));
...
...@@ -19,8 +19,8 @@ namespace dlib
This object is a tool for extracting random crops of objects from a set of
images. The crops are randomly jittered in scale, translation, and
rotation but more or less centered on objects specified by mmod_rect
objects (or other rectangle types with a compatible interface).
THREAD SAFETY
It is safe for multiple threads to make concurrent calls to this object's
operator() methods.
...@@ -40,6 +40,7 @@ namespace dlib
- #get_max_object_size() == 0.7
- #get_background_crops_fraction() == 0.5
- #get_translate_amount() == 0.1
- #get_min_object_coverage() == 1.0
!*/
void set_seed (
...@@ -152,7 +153,7 @@ namespace dlib
the longest edge of the object (i.e. either its height or width,
whichever is longer) is at least #get_min_object_length_long_dim() pixels
in length. When we say "object" here we are referring specifically to
the rectangle in the rectangle_type output by the cropper.
!*/
long get_min_object_length_short_dim (
...@@ -163,7 +164,7 @@ namespace dlib
the shortest edge of the object (i.e. either its height or width,
whichever is shorter) is at least #get_min_object_length_short_dim()
pixels in length. When we say "object" here we are referring
specifically to the rectangle in the rectangle_type output by the cropper.
!*/
void set_min_object_size (
...@@ -199,15 +200,34 @@ namespace dlib
- #get_max_object_size() == value
!*/
double get_min_object_coverage (
) const;
/*!
ensures
- When a chip is extracted, any object that has less than get_min_object_coverage() fraction of its
total area contained within the crop will have its ignore field set to true.
!*/
void set_min_object_coverage (
double value
);
/*!
requires
- 0 < value <= 1
ensures
- #get_min_object_coverage() == value
!*/
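
A short sketch (with hypothetical image and box data) of how the new setting combines
with the yolo_rect support added in this commit:

    random_cropper cropper;
    cropper.set_chip_dims(416, 416);
    cropper.set_min_object_coverage(0.5);  // ignore boxes with less than half their area in the crop
    matrix<rgb_pixel> img, crop;           // assumed to hold a real image
    std::vector<yolo_rect> boxes, crop_boxes;
    cropper(img, boxes, crop, crop_boxes);
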
template <
typename array_type,
typename rectangle_type
>
void append (
size_t num_crops,
const array_type& images,
const std::vector<std::vector<rectangle_type>>& rects,
array_type& crops,
std::vector<std::vector<rectangle_type>>& crop_rects
);
/*!
requires
...@@ -218,6 +238,8 @@ namespace dlib
- array_type is a type with an interface compatible with dlib::array or
std::vector and it must in turn contain image objects that implement the
interface defined in dlib/image_processing/generic_image.h
- rectangle_type is a type with an interface compatible with mmod_rect, such
as yolo_rect.
ensures
- Randomly extracts num_crops chips from images and appends them to the end
of crops. We also copy the object metadata for each extracted crop and
...@@ -230,14 +252,15 @@ namespace dlib
!*/
template <
typename array_type,
typename rectangle_type
>
void operator() (
size_t num_crops,
const array_type& images,
const std::vector<std::vector<rectangle_type>>& rects,
array_type& crops,
std::vector<std::vector<rectangle_type>>& crop_rects
);
/*!
requires
...@@ -247,6 +270,8 @@ namespace dlib
- array_type is a type with an interface compatible with dlib::array or
std::vector and it must in turn contain image objects that implement the
interface defined in dlib/image_processing/generic_image.h
- rectangle_type is a type with an interface compatible with mmod_rect, such
as yolo_rect.
ensures
- Randomly extracts num_crops chips from images. We also copy the object
metadata for each extracted crop and store it into #crop_rects. In
...@@ -259,13 +284,14 @@ namespace dlib
template <
typename array_type,
typename image_type,
typename rectangle_type
>
void operator() (
const array_type& images,
const std::vector<std::vector<rectangle_type>>& rects,
image_type& crop,
std::vector<rectangle_type>& crop_rects
);
/*!
requires
...@@ -277,6 +303,8 @@ namespace dlib
- array_type is a type with an interface compatible with dlib::array or
std::vector and it must in turn contain image objects that implement the
interface defined in dlib/image_processing/generic_image.h
- rectangle_type is a type with an interface compatible with mmod_rect, such
as yolo_rect.
ensures
- Selects a random image and creates a random crop from it. Specifically,
we pick a random index IDX < images.size() and then execute
...@@ -285,13 +313,14 @@ namespace dlib
template <
typename image_type1,
typename image_type2,
typename rectangle_type
>
void operator() (
const image_type1& img,
const std::vector<rectangle_type>& rects,
image_type2& crop,
std::vector<rectangle_type>& crop_rects
);
/*!
requires
...@@ -300,9 +329,11 @@ namespace dlib
dlib/image_processing/generic_image.h
- image_type2 == an image object that implements the interface defined in
dlib/image_processing/generic_image.h
- rectangle_type is a type with an interface compatible with mmod_rect, such
as yolo_rect.
ensures
- Extracts a random crop from img and copies over the rectangle_type objects
in rects to #crop_rects if they are contained inside the crop. Moreover,
rectangles are marked as ignore if less than get_min_object_coverage()
fraction of their area is contained inside the crop.
- #crop_rects.size() <= rects.size()
...@@ -343,4 +374,3 @@ namespace dlib
#endif // DLIB_RaNDOM_CROPPER_ABSTRACT_H_
...@@ -155,6 +155,7 @@ if (NOT USING_OLD_VISUAL_STUDIO_COMPILER)
add_example(dnn_instance_segmentation_train_ex)
add_example(dnn_metric_learning_on_images_ex)
add_gui_example(dnn_dcgan_train_ex)
add_gui_example(dnn_yolo_train_ex)
endif()
...
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
This is an example illustrating the use of the deep learning tools from the dlib C++
Library. I'm assuming you have already read the dnn_introduction_ex.cpp, the
dnn_introduction2_ex.cpp and the dnn_introduction3_ex.cpp examples. In this example
program we are going to show how one can train a YOLO detector. In particular, we will train
the YOLOv3 model like the one introduced in this paper:
"YOLOv3: An Incremental Improvement" by Joseph Redmon and Ali Farhadi.
This example program will work with any imglab dataset, such as:
- faces: http://dlib.net/files/data/dlib_face_detection_dataset-2016-09-30.tar.gz
- vehicles: http://dlib.net/files/data/dlib_rear_end_vehicles_v1.tar
Just uncompress the dataset and give the directory containing the training.xml and testing.xml
files as an argument to this program.
*/
#include <dlib/cmd_line_parser.h>
#include <dlib/data_io.h>
#include <dlib/dnn.h>
#include <dlib/gui_widgets.h>
#include <dlib/image_io.h>
#include <tools/imglab/src/metadata_editor.h>
using namespace std;
using namespace dlib;
// In the darknet namespace we define:
// - the network architecture: DarkNet53 backbone and detection head for YOLO.
// - a helper function to set up the detector: change the number of classes, etc.
namespace darknet
{
// backbone tags
template <typename SUBNET> using btag8 = add_tag_layer<8008, SUBNET>;
template <typename SUBNET> using btag16 = add_tag_layer<8016, SUBNET>;
template <typename SUBNET> using bskip8 = add_skip_layer<btag8, SUBNET>;
template <typename SUBNET> using bskip16 = add_skip_layer<btag16, SUBNET>;
// neck tags
template <typename SUBNET> using ntag8 = add_tag_layer<6008, SUBNET>;
template <typename SUBNET> using ntag16 = add_tag_layer<6016, SUBNET>;
template <typename SUBNET> using ntag32 = add_tag_layer<6032, SUBNET>;
template <typename SUBNET> using nskip8 = add_skip_layer<ntag8, SUBNET>;
template <typename SUBNET> using nskip16 = add_skip_layer<ntag16, SUBNET>;
template <typename SUBNET> using nskip32 = add_skip_layer<ntag32, SUBNET>;
// head tags
template <typename SUBNET> using htag8 = add_tag_layer<7008, SUBNET>;
template <typename SUBNET> using htag16 = add_tag_layer<7016, SUBNET>;
template <typename SUBNET> using htag32 = add_tag_layer<7032, SUBNET>;
template <typename SUBNET> using hskip8 = add_skip_layer<htag8, SUBNET>;
template <typename SUBNET> using hskip16 = add_skip_layer<htag16, SUBNET>;
// yolo tags
template <typename SUBNET> using ytag8 = add_tag_layer<4008, SUBNET>;
template <typename SUBNET> using ytag16 = add_tag_layer<4016, SUBNET>;
template <typename SUBNET> using ytag32 = add_tag_layer<4032, SUBNET>;
template <template <typename> class ACT, template <typename> class BN>
struct def
{
template <long nf, long ks, int s, typename SUBNET>
using conblock = ACT<BN<add_layer<con_<nf, ks, ks, s, s, ks / 2, ks / 2>, SUBNET>>>;
template <long nf, typename SUBNET>
using residual = add_prev1<conblock<nf, 3, 1, conblock<nf / 2, 1, 1, tag1<SUBNET>>>>;
template <long nf, long factor, typename SUBNET>
using conblock5 = conblock<nf, 1, 1,
conblock<nf * factor, 3, 1,
conblock<nf, 1, 1,
conblock<nf * factor, 3, 1,
conblock<nf, 1, 1, SUBNET>>>>>;
template <typename SUBNET> using res_64 = residual<64, SUBNET>;
template <typename SUBNET> using res_128 = residual<128, SUBNET>;
template <typename SUBNET> using res_256 = residual<256, SUBNET>;
template <typename SUBNET> using res_512 = residual<512, SUBNET>;
template <typename SUBNET> using res_1024 = residual<1024, SUBNET>;
template <typename INPUT>
using backbone53 = repeat<4, res_1024, conblock<1024, 3, 2,
btag16<repeat<8, res_512, conblock<512, 3, 2,
btag8<repeat<8, res_256, conblock<256, 3, 2,
repeat<2, res_128, conblock<128, 3, 2,
res_64< conblock<64, 3, 2,
conblock<32, 3, 1,
INPUT>>>>>>>>>>>>>;
// This is the layer that will be passed to the loss layer to get the detections from the network.
// The main thing to pay attention to when defining the YOLO output layer is that it should be
// a tag layer, followed by a sigmoid layer and a 1x1 convolution. The tag layer should be unique
// in the whole network definition, as the loss layer will use it to get the outputs. The number of
// filters in the convolutional layer should be (1 + 4 + num_classes) * num_anchors at that output.
// The 1 corresponds to the objectness in the loss layer and the 4 to the bounding box coordinates.
template <long num_classes, long nf, template <typename> class YTAG, template <typename> class NTAG, typename SUBNET>
using yolo = YTAG<sig<con<3 * (num_classes + 5), 1, 1, 1, 1,
conblock<nf, 3, 1,
NTAG<conblock5<nf / 2, 2,
SUBNET>>>>>>;
template <long num_classes>
using yolov3 = yolo<num_classes, 256, ytag8, ntag8,
concat2<htag8, btag8,
htag8<upsample<2, conblock<128, 1, 1,
nskip16<
yolo<num_classes, 512, ytag16, ntag16,
concat2<htag16, btag16,
htag16<upsample<2, conblock<256, 1, 1,
nskip32<
yolo<num_classes, 1024, ytag32, ntag32,
backbone53<input_rgb_image>>>>>>>>>>>>>>;
};
using yolov3_train_type = loss_yolo<ytag8, ytag16, ytag32, def<leaky_relu, bn_con>::yolov3<80>>;
using yolov3_infer_type = loss_yolo<ytag8, ytag16, ytag32, def<leaky_relu, affine>::yolov3<80>>;
void setup_detector(yolov3_train_type& net, const yolo_options& options)
{
// remove bias from bn inputs
disable_duplicative_biases(net);
// setup leaky relus
visit_computational_layers(net, [](leaky_relu_& l) { l = leaky_relu_(0.1); });
// enlarge the batch normalization stats window
set_all_bn_running_stats_window_sizes(net, 1000);
// set the number of filters for detection layers (they are located after the tag and sig layers)
const long nfo1 = options.anchors.at(tag_id<ytag8>::id).size() * (options.labels.size() + 5);
const long nfo2 = options.anchors.at(tag_id<ytag16>::id).size() * (options.labels.size() + 5);
const long nfo3 = options.anchors.at(tag_id<ytag32>::id).size() * (options.labels.size() + 5);
layer<ytag8, 2>(net).layer_details().set_num_filters(nfo1);
layer<ytag16, 2>(net).layer_details().set_num_filters(nfo2);
layer<ytag32, 2>(net).layer_details().set_num_filters(nfo3);
}
}
// YOLO expects square images; in this example we make the inputs square by letterboxing them.
rectangle_transform preprocess_image(const matrix<rgb_pixel>& image, matrix<rgb_pixel>& output, const long image_size)
{
return rectangle_transform(inv(letterbox_image(image, output, image_size)));
}
// YOLO outputs the bounding boxes in the coordinate system of the input (letterboxed) image, so we need to convert them
// back to the original image.
void postprocess_detections(const rectangle_transform& tform, std::vector<yolo_rect>& detections)
{
for (auto& d : detections)
d.rect = tform(d.rect);
}
int main(const int argc, const char** argv)
try
{
command_line_parser parser;
parser.add_option("size", "image size for training (default: 416)", 1);
parser.add_option("learning-rate", "initial learning rate (default: 0.001)", 1);
parser.add_option("batch-size", "mini batch size (default: 8)", 1);
parser.add_option("burnin", "learning rate burnin steps (default: 1000)", 1);
parser.add_option("patience", "number of steps without progress (default: 10000)", 1);
parser.add_option("workers", "number of worker threads to load data (default: 4)", 1);
parser.add_option("gpus", "number of GPUs to run the training on (default: 1)", 1);
parser.add_option("test", "test the detector with a threshold (default: 0.01)", 1);
parser.add_option("visualize", "visualize data augmentation instead of training");
parser.add_option("map", "compute the mean average precision");
parser.add_option("anchors", "Do nothing but compute <arg1> anchor boxes using K-Means and print their shapes.", 1);
parser.set_group_name("Help Options");
parser.add_option("h", "alias of --help");
parser.add_option("help", "display this message and exit");
parser.parse(argc, argv);
if (parser.number_of_arguments() == 0 || parser.option("h") || parser.option("help"))
{
parser.print_options();
cout << "Give the path to a folder containing the training.xml file." << endl;
return 0;
}
const double learning_rate = get_option(parser, "learning-rate", 0.001);
const size_t patience = get_option(parser, "patience", 10000);
const size_t batch_size = get_option(parser, "batch-size", 8);
const size_t burnin = get_option(parser, "burnin", 1000);
const size_t image_size = get_option(parser, "size", 416);
const size_t num_workers = get_option(parser, "workers", 4);
const size_t num_gpus = get_option(parser, "gpus", 1);
const string data_directory = parser[0];
const string sync_file_name = "yolov3_sync";
image_dataset_metadata::dataset dataset;
image_dataset_metadata::load_image_dataset_metadata(dataset, data_directory + "/training.xml");
cout << "# images: " << dataset.images.size() << endl;
std::map<string, size_t> labels;
size_t num_objects = 0;
for (const auto& im : dataset.images)
{
for (const auto& b : im.boxes)
{
labels[b.label]++;
++num_objects;
}
}
cout << "# labels: " << labels.size() << endl;
yolo_options options;
color_mapper string_to_color;
for (const auto& label : labels)
{
cout << " - " << label.first << ": " << label.second;
cout << " (" << (100.0*label.second)/num_objects << "%)\n";
options.labels.push_back(label.first);
string_to_color(label.first);
}
// If the default anchor boxes don't fit your data well you should recompute them.
// Here's a simple way to do it using K-Means clustering. Note that the approach
// shown below is suboptimal, since it doesn't group the bounding boxes by size.
// Grouping the bounding boxes by size and computing the K-Means on each group
// would make more sense, since each stride of the network is meant to output
// boxes at a particular size, but that is very specific to the network architecture
// and the dataset itself.
if (parser.option("anchors"))
{
const auto num_clusters = std::stoul(parser.option("anchors").argument());
std::vector<dpoint> samples;
// First we need to rescale the bounding boxes to match the image size at training time.
for (const auto& image_info : dataset.images)
{
const auto scale = image_size / std::max<double>(image_info.width, image_info.height);
for (const auto& box : image_info.boxes)
{
dpoint sample(box.rect.width(), box.rect.height());
samples.push_back(sample*scale);
}
}
// Now we can compute K-Means clustering
randomize_samples(samples);
cout << "Computing anchors for " << samples.size() << " samples" << endl;
std::vector<dpoint> anchors;
pick_initial_centers(num_clusters, anchors, samples);
find_clusters_using_kmeans(samples, anchors);
std::sort(anchors.begin(), anchors.end(), [](const dpoint& a, const dpoint& b){ return prod(a) < prod(b); });
for (const dpoint& c : anchors)
cout << round(c(0)) << 'x' << round(c(1)) << endl;
// And check the average IoU of the newly computed anchor boxes and the training samples.
double average_iou = 0;
for (const dpoint& s : samples)
{
drectangle sample = centered_drect(dpoint(0, 0), s.x(), s.y());
double best_iou = 0;
for (const dpoint& a : anchors)
{
drectangle anchor = centered_drect(dpoint(0, 0), a.x(), a.y());
best_iou = std::max(best_iou, box_intersection_over_union(sample, anchor));
}
average_iou += best_iou;
}
cout << "Average IoU: " << average_iou / samples.size() << endl;
return EXIT_SUCCESS;
}
// When computing the objectness loss in YOLO, predictions that do not have an IoU
// of at least options.iou_ignore_threshold with any ground truth box are treated
// as not detecting an object, and therefore incur loss. Predictions whose IoU with
// some truth box is above this threshold are left alone by the objectness loss.
// Typical settings for this threshold are in the range 0.5 to 0.7.
options.iou_ignore_threshold = 0.7;
// By setting this to a value < 1, we are telling the model to update all the predictions
// as long as the anchor box has an IoU > 0.2 with a ground truth.
options.iou_anchor_threshold = 0.2;
// These are the anchors computed on the COCO dataset, presented in the YOLOv3 paper.
options.add_anchors<darknet::ytag8>({{10, 13}, {16, 30}, {33, 23}});
options.add_anchors<darknet::ytag16>({{30, 61}, {62, 45}, {59, 119}});
options.add_anchors<darknet::ytag32>({{116, 90}, {156, 198}, {373, 326}});
darknet::yolov3_train_type net(options);
darknet::setup_detector(net, options);
// The training process can be unstable at the beginning. For this reason, we exponentially
// increase the learning rate during the first burnin steps.
const matrix<double> learning_rate_schedule = learning_rate * pow(linspace(1e-12, 1, burnin), 4);
// In case we have several GPUs, we can tell the dnn_trainer to make use of them.
std::vector<int> gpus(num_gpus);
iota(gpus.begin(), gpus.end(), 0);
// We initialize the trainer here, as it will be used in several contexts, depending
// on the arguments passed to the program.
dnn_trainer<darknet::yolov3_train_type> trainer(net, sgd(0.0005, 0.9), gpus);
trainer.be_verbose();
trainer.set_mini_batch_size(batch_size);
trainer.set_learning_rate_schedule(learning_rate_schedule);
trainer.set_synchronization_file(sync_file_name, chrono::minutes(15));
cout << trainer;
// If the training has started and a synchronization file has already been saved to disk,
// we can re-run this program with the --test option and a confidence threshold to see
// how the training is going.
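// A hypothetical invocation (the program name and dataset path are placeholders):
// ./yolo_train_ex /path/to/dataset --test 0.25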
if (parser.option("test"))
{
if (!file_exists(sync_file_name))
{
cout << "Could not find file " << sync_file_name << endl;
return EXIT_FAILURE;
}
const double threshold = get_option(parser, "test", 0.01);
image_window win;
matrix<rgb_pixel> image, resized;
for (const auto& im : dataset.images)
{
win.clear_overlay();
load_image(image, data_directory + "/" + im.filename);
win.set_title(im.filename);
win.set_image(image);
const auto tform = preprocess_image(image, resized, image_size);
auto detections = net.process(resized, threshold);
postprocess_detections(tform, detections);
cout << "# detections: " << detections.size() << endl;
for (const auto& det : detections)
{
win.add_overlay(det.rect, string_to_color(det.label), det.label);
cout << det.label << ": " << det.rect << " " << det.detection_confidence << endl;
}
cin.get();
}
return EXIT_SUCCESS;
}
// If the training has started and a synchronization file has already been saved to disk,
// we can re-run this program with the --map option to compute the mean average precision
// on the test set.
if (parser.option("map"))
{
image_dataset_metadata::dataset dataset;
image_dataset_metadata::load_image_dataset_metadata(dataset, data_directory + "/testing.xml");
if (!file_exists(sync_file_name))
{
cout << "Could not find file " << sync_file_name << endl;
return EXIT_FAILURE;
}
matrix<rgb_pixel> image, resized;
std::map<std::string, std::vector<std::pair<double, bool>>> hits;
std::map<std::string, unsigned long> missing;
for (const auto& label : options.labels)
{
hits[label] = std::vector<std::pair<double, bool>>();
missing[label] = 0;
}
cout << "computing mean average precision for " << dataset.images.size() << " images..." << endl;
for (size_t i = 0; i < dataset.images.size(); ++i)
{
const auto& im = dataset.images[i];
load_image(image, data_directory + "/" + im.filename);
const auto tform = preprocess_image(image, resized, image_size);
auto dets = net.process(resized, 0.005);
postprocess_detections(tform, dets);
std::vector<bool> used(dets.size(), false);
// true positives: truths matched by detections
for (size_t t = 0; t < im.boxes.size(); ++t)
{
bool found_match = false;
for (size_t d = 0; d < dets.size(); ++d)
{
if (used[d])
continue;
if (im.boxes[t].label == dets[d].label &&
box_intersection_over_union(drectangle(im.boxes[t].rect), dets[d].rect) > 0.5)
{
used[d] = true;
found_match = true;
hits.at(dets[d].label).emplace_back(dets[d].detection_confidence, true);
break;
}
}
// false negatives: truths not matched
if (!found_match)
missing.at(im.boxes[t].label)++;
}
// false positives: detections not matched
for (size_t d = 0; d < dets.size(); ++d)
{
if (!used[d])
hits.at(dets[d].label).emplace_back(dets[d].detection_confidence, false);
}
cout << "progress: " << i << '/' << dataset.images.size() << "\t\t\t\r" << flush;
}
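// dlib's average_precision takes the ranked (confidence, true-positive) pairs, best
// first, plus the number of missed ground truths, and returns the area under the
// precision-recall curve; averaging it over all labels yields the mAP.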
double map = 0;
for (auto& item : hits)
{
std::sort(item.second.rbegin(), item.second.rend());
const double ap = average_precision(item.second, missing[item.first]);
cout << rpad(item.first + ": ", 16, " ") << ap * 100 << '%' << endl;
map += ap;
}
cout << rpad(string("mAP: "), 16, " ") << map / hits.size() * 100 << '%' << endl;
return EXIT_SUCCESS;
}
// Create some data loaders which will load the data and perform some data augmentation.
dlib::pipe<std::pair<matrix<rgb_pixel>, std::vector<yolo_rect>>> train_data(1000);
const auto loader = [&dataset, &data_directory, &train_data, &image_size](time_t seed)
{
dlib::rand rnd(time(nullptr) + seed);
matrix<rgb_pixel> image, rotated;
std::pair<matrix<rgb_pixel>, std::vector<yolo_rect>> temp;
random_cropper cropper;
cropper.set_seed(time(nullptr) + seed);
cropper.set_chip_dims(image_size, image_size);
cropper.set_max_object_size(0.9);
cropper.set_min_object_size(10, 10);
cropper.set_max_rotation_degrees(10);
cropper.set_translate_amount(0.5);
cropper.set_randomly_flip(true);
cropper.set_background_crops_fraction(0);
cropper.set_min_object_coverage(0.8);
while (train_data.is_enabled())
{
// Reset the boxes from the previous iteration; otherwise they would accumulate
// across samples.
temp.second.clear();
const auto idx = rnd.get_random_32bit_number() % dataset.images.size();
load_image(image, data_directory + "/" + dataset.images[idx].filename);
for (const auto& box : dataset.images[idx].boxes)
temp.second.emplace_back(box.rect, 1, box.label);
// We alternate between augmenting the full image and random cropping
if (rnd.get_random_double() > 0.5)
{
rectangle_transform tform = rotate_image(
image,
rotated,
rnd.get_double_in_range(-5 * pi / 180, 5 * pi / 180),
interpolate_bilinear());
for (auto& box : temp.second)
box.rect = tform(box.rect);
tform = letterbox_image(rotated, temp.first, image_size);
for (auto& box : temp.second)
box.rect = tform(box.rect);
if (rnd.get_random_double() > 0.5)
{
tform = flip_image_left_right(temp.first);
for (auto& box : temp.second)
box.rect = tform(box.rect);
}
}
else
{
std::vector<yolo_rect> boxes = temp.second;
cropper(image, boxes, temp.first, temp.second);
}
disturb_colors(temp.first, rnd);
train_data.enqueue(temp);
}
};
std::vector<thread> data_loaders;
for (size_t i = 0; i < num_workers; ++i)
data_loaders.emplace_back([loader, i]() { loader(i + 1); });
// It is always a good idea to visualize the training samples. By passing the --visualize
// flag, we can see the training samples that will be fed to the dnn_trainer.
if (parser.option("visualize"))
{
image_window win;
while (true)
{
std::pair<matrix<rgb_pixel>, std::vector<yolo_rect>> temp;
train_data.dequeue(temp);
win.clear_overlay();
win.set_image(temp.first);
for (const auto& r : temp.second)
{
auto color = string_to_color(r.label);
// make the ignored boxes semi-transparent and cross them out
if (r.ignore)
{
color.alpha = 128;
win.add_overlay(r.rect.tl_corner(), r.rect.br_corner(), color);
win.add_overlay(r.rect.tr_corner(), r.rect.bl_corner(), color);
}
win.add_overlay(r.rect, color, r.label);
}
cout << "Press enter to visualize the next training sample.";
cin.get();
}
}
std::vector<matrix<rgb_pixel>> images;
std::vector<std::vector<yolo_rect>> bboxes;
// The main training loop, which we will reuse for the warm-up and the rest of the training.
const auto train = [&images, &bboxes, &train_data, &trainer]()
{
images.clear();
bboxes.clear();
pair<matrix<rgb_pixel>, std::vector<yolo_rect>> temp;
while (images.size() < trainer.get_mini_batch_size())
{
train_data.dequeue(temp);
images.push_back(move(temp.first));
bboxes.push_back(move(temp.second));
}
trainer.train_one_step(images, bboxes);
};
cout << "training started with " << burnin << " burn-in steps" << endl;
while (trainer.get_train_one_step_calls() < burnin)
train();
cout << "burn-in finished" << endl;
trainer.get_net();
trainer.set_learning_rate(learning_rate);
trainer.set_min_learning_rate(learning_rate * 1e-3);
trainer.set_learning_rate_shrink_factor(0.1);
trainer.set_iterations_without_progress_threshold(patience);
cout << trainer << endl;
while (trainer.get_learning_rate() >= trainer.get_min_learning_rate())
train();
cout << "training done" << endl;
trainer.get_net();
train_data.disable();
for (auto& worker : data_loaders)
worker.join();
serialize("yolov3.dnn") << net;
return EXIT_SUCCESS;
}
catch (const std::exception& e)
{
cout << e.what() << endl;
return EXIT_FAILURE;
}
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
<name>Testing faces</name> <name>Testing faces</name>
<comment>These are images from the PASCAL VOC 2011 dataset.</comment> <comment>These are images from the PASCAL VOC 2011 dataset.</comment>
<images> <images>
<image file='2008_002470.jpg'> <image file='2008_002470.jpg' width='500' height='332'>
<box top='181' left='274' width='52' height='53'/> <box top='181' left='274' width='52' height='53'/>
<box top='156' left='55' width='44' height='44'/> <box top='156' left='55' width='44' height='44'/>
<box top='166' left='146' width='37' height='37'/> <box top='166' left='146' width='37' height='37'/>
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
<box top='74' left='233' width='44' height='44'/> <box top='74' left='233' width='44' height='44'/>
<box top='86' left='178' width='37' height='37'/> <box top='86' left='178' width='37' height='37'/>
</image> </image>
<image file='2008_002506.jpg'> <image file='2008_002506.jpg' width='500' height='375'>
<box top='78' left='329' width='109' height='109'/> <box top='78' left='329' width='109' height='109'/>
<box top='95' left='224' width='91' height='91'/> <box top='95' left='224' width='91' height='91'/>
<box top='65' left='125' width='90' height='91'/> <box top='65' left='125' width='90' height='91'/>
</image> </image>
<image file='2008_004176.jpg'> <image file='2008_004176.jpg' width='480' height='438'>
<box top='230' left='206' width='37' height='37'/> <box top='230' left='206' width='37' height='37'/>
<box top='118' left='162' width='37' height='37'/> <box top='118' left='162' width='37' height='37'/>
<box top='82' left='190' width='37' height='37'/> <box top='82' left='190' width='37' height='37'/>
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
<box top='86' left='110' width='37' height='37'/> <box top='86' left='110' width='37' height='37'/>
<box top='102' left='282' width='37' height='37'/> <box top='102' left='282' width='37' height='37'/>
</image> </image>
<image file='2008_007676.jpg'> <image file='2008_007676.jpg' width='500' height='334'>
<box top='62' left='226' width='37' height='37'/> <box top='62' left='226' width='37' height='37'/>
<box top='113' left='194' width='44' height='44'/> <box top='113' left='194' width='44' height='44'/>
<box top='130' left='262' width='37' height='37'/> <box top='130' left='262' width='37' height='37'/>
...@@ -35,9 +35,9 @@ ...@@ -35,9 +35,9 @@
<box top='141' left='107' width='52' height='53'/> <box top='141' left='107' width='52' height='53'/>
<box top='84' left='137' width='44' height='44'/> <box top='84' left='137' width='44' height='44'/>
</image> </image>
<image file='2009_004587.jpg'> <image file='2009_004587.jpg' width='400' height='500'>
<box top='46' left='154' width='75' height='76'/> <box top='46' left='154' width='75' height='76'/>
<box top='280' left='266' width='63' height='63'/> <box top='280' left='266' width='63' height='63'/>
</image> </image>
</images> </images>
</dataset> </dataset>
\ No newline at end of file
...@@ -5,10 +5,9 @@ ...@@ -5,10 +5,9 @@
<comment>These are images from the PASCAL VOC 2011 dataset. <comment>These are images from the PASCAL VOC 2011 dataset.
The face landmarks are from dlib's shape_predictor_68_face_landmarks.dat The face landmarks are from dlib's shape_predictor_68_face_landmarks.dat
landmarking model. The model uses the 68 landmark scheme used by the iBUG landmarking model. The model uses the 68 landmark scheme used by the iBUG
300-W dataset. 300-W dataset.</comment>
</comment>
<images> <images>
<image file='2008_002470.jpg'> <image file='2008_002470.jpg' width='500' height='332'>
<box top='181' left='274' width='52' height='53'> <box top='181' left='274' width='52' height='53'>
<part name='00' x='277' y='194'/> <part name='00' x='277' y='194'/>
<part name='01' x='278' y='200'/> <part name='01' x='278' y='200'/>
...@@ -430,7 +429,7 @@ ...@@ -430,7 +429,7 @@
<part name='67' x='196' y='112'/> <part name='67' x='196' y='112'/>
</box> </box>
</image> </image>
<image file='2008_002506.jpg'> <image file='2008_002506.jpg' width='500' height='375'>
<box top='78' left='329' width='109' height='109'> <box top='78' left='329' width='109' height='109'>
<part name='00' x='342' y='134'/> <part name='00' x='342' y='134'/>
<part name='01' x='345' y='145'/> <part name='01' x='345' y='145'/>
...@@ -642,7 +641,7 @@ ...@@ -642,7 +641,7 @@
<part name='67' x='163' y='133'/> <part name='67' x='163' y='133'/>
</box> </box>
</image> </image>
<image file='2008_004176.jpg'> <image file='2008_004176.jpg' width='480' height='438'>
<box top='230' left='206' width='37' height='37'> <box top='230' left='206' width='37' height='37'>
<part name='00' x='206' y='241'/> <part name='00' x='206' y='241'/>
<part name='01' x='206' y='245'/> <part name='01' x='206' y='245'/>
...@@ -1134,7 +1133,7 @@ ...@@ -1134,7 +1133,7 @@
<part name='67' x='294' y='126'/> <part name='67' x='294' y='126'/>
</box> </box>
</image> </image>
<image file='2008_007676.jpg'> <image file='2008_007676.jpg' width='500' height='334'>
<box top='62' left='226' width='37' height='37'> <box top='62' left='226' width='37' height='37'>
<part name='00' x='223' y='72'/> <part name='00' x='223' y='72'/>
<part name='01' x='224' y='77'/> <part name='01' x='224' y='77'/>
...@@ -1626,7 +1625,7 @@ ...@@ -1626,7 +1625,7 @@
<part name='67' x='160' y='115'/> <part name='67' x='160' y='115'/>
</box> </box>
</image> </image>
<image file='2009_004587.jpg'> <image file='2009_004587.jpg' width='400' height='500'>
<box top='46' left='154' width='75' height='76'> <box top='46' left='154' width='75' height='76'>
<part name='00' x='147' y='74'/> <part name='00' x='147' y='74'/>
<part name='01' x='147' y='84'/> <part name='01' x='147' y='84'/>
...@@ -1769,4 +1768,4 @@ ...@@ -1769,4 +1768,4 @@
</box> </box>
</image> </image>
</images> </images>
</dataset> </dataset>
\ No newline at end of file
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
<name>Training faces</name> <name>Training faces</name>
<comment>These are images from the PASCAL VOC 2011 dataset.</comment> <comment>These are images from the PASCAL VOC 2011 dataset.</comment>
<images> <images>
<image file='2007_007763.jpg'> <image file='2007_007763.jpg' width='500' height='375'>
<box top='90' left='194' width='37' height='37'/> <box top='90' left='194' width='37' height='37'/>
<box top='114' left='158' width='37' height='37'/> <box top='114' left='158' width='37' height='37'/>
<box top='89' left='381' width='45' height='44'/> <box top='89' left='381' width='45' height='44'/>
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
<box top='86' left='294' width='37' height='37'/> <box top='86' left='294' width='37' height='37'/>
<box top='233' left='309' width='45' height='44'/> <box top='233' left='309' width='45' height='44'/>
</image> </image>
<image file='2008_002079.jpg'> <image file='2008_002079.jpg' width='500' height='375'>
<box top='166' left='407' width='37' height='37'/> <box top='166' left='407' width='37' height='37'/>
<box top='134' left='122' width='37' height='37'/> <box top='134' left='122' width='37' height='37'/>
<box top='138' left='346' width='37' height='37'/> <box top='138' left='346' width='37' height='37'/>
...@@ -21,11 +21,11 @@ ...@@ -21,11 +21,11 @@
<box top='134' left='62' width='37' height='37'/> <box top='134' left='62' width='37' height='37'/>
<box top='194' left='41' width='44' height='44'/> <box top='194' left='41' width='44' height='44'/>
</image> </image>
<image file='2008_001009.jpg'> <image file='2008_001009.jpg' width='360' height='480'>
<box top='79' left='145' width='76' height='76'/> <box top='79' left='145' width='76' height='76'/>
<box top='214' left='125' width='90' height='91'/> <box top='214' left='125' width='90' height='91'/>
</image> </image>
<image file='2008_001322.jpg'> <image file='2008_001322.jpg' width='500' height='375'>
<box top='162' left='104' width='76' height='76'/> <box top='162' left='104' width='76' height='76'/>
<box top='218' left='232' width='63' height='63'/> <box top='218' left='232' width='63' height='63'/>
<box top='155' left='344' width='90' height='90'/> <box top='155' left='344' width='90' height='90'/>
......
...@@ -5,10 +5,9 @@ ...@@ -5,10 +5,9 @@
<comment>These are images from the PASCAL VOC 2011 dataset. <comment>These are images from the PASCAL VOC 2011 dataset.
The face landmarks are from dlib's shape_predictor_68_face_landmarks.dat The face landmarks are from dlib's shape_predictor_68_face_landmarks.dat
landmarking model. The model uses the 68 landmark scheme used by the iBUG landmarking model. The model uses the 68 landmark scheme used by the iBUG
300-W dataset. 300-W dataset.</comment>
</comment>
<images> <images>
<image file='2007_007763.jpg'> <image file='2007_007763.jpg' width='500' height='375'>
<box top='90' left='194' width='37' height='37'> <box top='90' left='194' width='37' height='37'>
<part name='00' x='201' y='107'/> <part name='00' x='201' y='107'/>
<part name='01' x='201' y='110'/> <part name='01' x='201' y='110'/>
...@@ -500,7 +499,7 @@ ...@@ -500,7 +499,7 @@
<part name='67' x='323' y='267'/> <part name='67' x='323' y='267'/>
</box> </box>
</image> </image>
<image file='2008_002079.jpg'> <image file='2008_002079.jpg' width='500' height='375'>
<box top='166' left='406' width='37' height='37'> <box top='166' left='406' width='37' height='37'>
<part name='00' x='412' y='179'/> <part name='00' x='412' y='179'/>
<part name='01' x='411' y='183'/> <part name='01' x='411' y='183'/>
...@@ -922,7 +921,7 @@ ...@@ -922,7 +921,7 @@
<part name='67' x='68' y='227'/> <part name='67' x='68' y='227'/>
</box> </box>
</image> </image>
<image file='2008_001009.jpg'> <image file='2008_001009.jpg' width='360' height='480'>
<box top='79' left='145' width='76' height='76'> <box top='79' left='145' width='76' height='76'>
<part name='00' x='145' y='115'/> <part name='00' x='145' y='115'/>
<part name='01' x='148' y='124'/> <part name='01' x='148' y='124'/>
...@@ -1064,7 +1063,7 @@ ...@@ -1064,7 +1063,7 @@
<part name='67' x='168' y='280'/> <part name='67' x='168' y='280'/>
</box> </box>
</image> </image>
<image file='2008_001322.jpg'> <image file='2008_001322.jpg' width='500' height='375'>
<box top='162' left='104' width='76' height='76'> <box top='162' left='104' width='76' height='76'>
<part name='00' x='106' y='183'/> <part name='00' x='106' y='183'/>
<part name='01' x='106' y='193'/> <part name='01' x='106' y='193'/>
...@@ -1277,4 +1276,4 @@ ...@@ -1277,4 +1276,4 @@
</box> </box>
</image> </image>
</images> </images>
</dataset> </dataset>
\ No newline at end of file
...@@ -247,6 +247,8 @@ int cluster_dataset( ...@@ -247,6 +247,8 @@ int cluster_dataset(
{ {
idata[i].first = std::numeric_limits<double>::infinity(); idata[i].first = std::numeric_limits<double>::infinity();
idata[i].second.filename = data.images[i].filename; idata[i].second.filename = data.images[i].filename;
idata[i].second.width = data.images[i].width;
idata[i].second.height = data.images[i].height;
if (!has_non_ignored_boxes(data.images[i])) if (!has_non_ignored_boxes(data.images[i]))
continue; continue;
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <dlib/dir_nav.h> #include <dlib/dir_nav.h>
const char* VERSION = "1.17"; const char* VERSION = "1.18";
...@@ -332,6 +332,8 @@ void rotate_dataset(const command_line_parser& parser) ...@@ -332,6 +332,8 @@ void rotate_dataset(const command_line_parser& parser)
load_image(img, metadata.images[i].filename); load_image(img, metadata.images[i].filename);
const point_transform_affine tran = rotate_image(img, temp, angle*pi/180); const point_transform_affine tran = rotate_image(img, temp, angle*pi/180);
metadata.images[i].width = temp.nc();
metadata.images[i].height = temp.nr();
if (parser.option("jpg")) if (parser.option("jpg"))
{ {
filename = to_jpg_name(filename); filename = to_jpg_name(filename);
...@@ -359,6 +361,32 @@ void rotate_dataset(const command_line_parser& parser) ...@@ -359,6 +361,32 @@ void rotate_dataset(const command_line_parser& parser)
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void add_width_and_height_metadata(const command_line_parser& parser)
{
for (unsigned long i = 0; i < parser.number_of_arguments(); ++i)
{
image_dataset_metadata::dataset metadata;
const string datasource = parser[i];
load_image_dataset_metadata(metadata, datasource);
// Set the current directory to be the one that contains the
// metadata file. We do this because the file might contain
// file paths which are relative to this folder.
set_current_dir(get_parent_directory(file(datasource)));
parallel_for(0, metadata.images.size(), [&](long j)
{
array2d<rgb_pixel> img;
load_image(img, metadata.images[j].filename);
metadata.images[j].width = img.nc();
metadata.images[j].height = img.nr();
});
save_image_dataset_metadata(metadata, datasource);
}
}
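// Typical usage from a shell: imglab --add-width-height-metadata mydataset.xml
// (mydataset.xml is a placeholder for any dataset file produced by imglab).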
// ----------------------------------------------------------------------------------------
int resample_dataset(const command_line_parser& parser) int resample_dataset(const command_line_parser& parser)
{ {
if (parser.number_of_arguments() != 1) if (parser.number_of_arguments() != 1)
...@@ -447,6 +475,8 @@ int resample_dataset(const command_line_parser& parser) ...@@ -447,6 +475,8 @@ int resample_dataset(const command_line_parser& parser)
std::ostringstream sout; std::ostringstream sout;
sout << hex << murmur_hash3_128bit(&chip[0][0], chip.size()*sizeof(chip[0][0])).second; sout << hex << murmur_hash3_128bit(&chip[0][0], chip.size()*sizeof(chip[0][0])).second;
dimg.filename = data.images[i].filename + "_RESAMPLED_"+sout.str()+".png"; dimg.filename = data.images[i].filename + "_RESAMPLED_"+sout.str()+".png";
dimg.width = chip.nc();
dimg.height = chip.nr();
if (parser.option("jpg")) if (parser.option("jpg"))
{ {
...@@ -588,6 +618,8 @@ int main(int argc, char** argv) ...@@ -588,6 +618,8 @@ int main(int argc, char** argv)
"The parts are instead simply mirrored to the flipped dataset.", 1); "The parts are instead simply mirrored to the flipped dataset.", 1);
parser.add_option("rotate", "Read an XML image dataset and output a copy that is rotated counter clockwise by <arg> degrees. " parser.add_option("rotate", "Read an XML image dataset and output a copy that is rotated counter clockwise by <arg> degrees. "
"The output is saved to an XML file prefixed with rotated_<arg>.",1); "The output is saved to an XML file prefixed with rotated_<arg>.",1);
parser.add_option("add-width-height-metadata", "Open the given xml files and set the width and height image metadata fields "
"for every image. This involves loading each image to find these values.");
parser.add_option("cluster", "Cluster all the objects in an XML file into <arg> different clusters (pass 0 to find automatically) and save " parser.add_option("cluster", "Cluster all the objects in an XML file into <arg> different clusters (pass 0 to find automatically) and save "
"the results as cluster_###.xml and cluster_###.jpg files.",1); "the results as cluster_###.xml and cluster_###.jpg files.",1);
parser.add_option("ignore", "Mark boxes labeled as <arg> as ignored. The resulting XML file is output as a separate file and the original is not modified.",1); parser.add_option("ignore", "Mark boxes labeled as <arg> as ignored. The resulting XML file is output as a separate file and the original is not modified.",1);
...@@ -612,7 +644,7 @@ int main(int argc, char** argv) ...@@ -612,7 +644,7 @@ int main(int argc, char** argv)
const char* singles[] = {"h","c","r","l","files","convert","parts","rmdiff", "rmtrunc", "rmdupes", "seed", "shuffle", "split", "add", const char* singles[] = {"h","c","r","l","files","convert","parts","rmdiff", "rmtrunc", "rmdupes", "seed", "shuffle", "split", "add",
"flip-basic", "flip", "rotate", "tile", "size", "cluster", "resample", "min-object-size", "rmempty", "flip-basic", "flip", "rotate", "tile", "size", "cluster", "resample", "min-object-size", "rmempty",
"crop-size", "cropped-object-size", "rmlabel", "rm-other-labels", "rm-if-overlaps", "sort-num-objects", "crop-size", "cropped-object-size", "rmlabel", "rm-other-labels", "rm-if-overlaps", "sort-num-objects",
"one-object-per-image", "jpg", "rmignore", "sort", "split-train-test", "box-images"}; "one-object-per-image", "jpg", "rmignore", "sort", "split-train-test", "box-images", "add-width-height-metadata"};
parser.check_one_time_options(singles); parser.check_one_time_options(singles);
const char* c_sub_ops[] = {"r", "convert"}; const char* c_sub_ops[] = {"r", "convert"};
parser.check_sub_options("c", c_sub_ops); parser.check_sub_options("c", c_sub_ops);
...@@ -637,6 +669,7 @@ int main(int argc, char** argv) ...@@ -637,6 +669,7 @@ int main(int argc, char** argv)
parser.check_incompatible_options("c", "flip-basic"); parser.check_incompatible_options("c", "flip-basic");
parser.check_incompatible_options("flip", "flip-basic"); parser.check_incompatible_options("flip", "flip-basic");
parser.check_incompatible_options("c", "rotate"); parser.check_incompatible_options("c", "rotate");
parser.check_incompatible_options("c", "add-width-height-metadata");
parser.check_incompatible_options("c", "rename"); parser.check_incompatible_options("c", "rename");
parser.check_incompatible_options("c", "ignore"); parser.check_incompatible_options("c", "ignore");
parser.check_incompatible_options("c", "parts"); parser.check_incompatible_options("c", "parts");
...@@ -650,6 +683,7 @@ int main(int argc, char** argv) ...@@ -650,6 +683,7 @@ int main(int argc, char** argv)
parser.check_incompatible_options("l", "flip"); parser.check_incompatible_options("l", "flip");
parser.check_incompatible_options("l", "flip-basic"); parser.check_incompatible_options("l", "flip-basic");
parser.check_incompatible_options("l", "rotate"); parser.check_incompatible_options("l", "rotate");
parser.check_incompatible_options("l", "add-width-height-metadata");
parser.check_incompatible_options("files", "rename"); parser.check_incompatible_options("files", "rename");
parser.check_incompatible_options("files", "ignore"); parser.check_incompatible_options("files", "ignore");
parser.check_incompatible_options("files", "add"); parser.check_incompatible_options("files", "add");
...@@ -657,22 +691,27 @@ int main(int argc, char** argv) ...@@ -657,22 +691,27 @@ int main(int argc, char** argv)
parser.check_incompatible_options("files", "flip"); parser.check_incompatible_options("files", "flip");
parser.check_incompatible_options("files", "flip-basic"); parser.check_incompatible_options("files", "flip-basic");
parser.check_incompatible_options("files", "rotate"); parser.check_incompatible_options("files", "rotate");
parser.check_incompatible_options("files", "add-width-height-metadata");
parser.check_incompatible_options("add", "flip"); parser.check_incompatible_options("add", "flip");
parser.check_incompatible_options("add", "flip-basic"); parser.check_incompatible_options("add", "flip-basic");
parser.check_incompatible_options("add", "rotate"); parser.check_incompatible_options("add", "rotate");
parser.check_incompatible_options("add", "add-width-height-metadata");
parser.check_incompatible_options("add", "tile"); parser.check_incompatible_options("add", "tile");
parser.check_incompatible_options("flip", "tile"); parser.check_incompatible_options("flip", "tile");
parser.check_incompatible_options("flip-basic", "tile"); parser.check_incompatible_options("flip-basic", "tile");
parser.check_incompatible_options("rotate", "tile"); parser.check_incompatible_options("rotate", "tile");
parser.check_incompatible_options("add-width-height-metadata", "tile");
parser.check_incompatible_options("cluster", "tile"); parser.check_incompatible_options("cluster", "tile");
parser.check_incompatible_options("resample", "tile"); parser.check_incompatible_options("resample", "tile");
parser.check_incompatible_options("flip", "cluster"); parser.check_incompatible_options("flip", "cluster");
parser.check_incompatible_options("flip-basic", "cluster"); parser.check_incompatible_options("flip-basic", "cluster");
parser.check_incompatible_options("rotate", "cluster"); parser.check_incompatible_options("rotate", "cluster");
parser.check_incompatible_options("add-width-height-metadata", "cluster");
parser.check_incompatible_options("add", "cluster"); parser.check_incompatible_options("add", "cluster");
parser.check_incompatible_options("flip", "resample"); parser.check_incompatible_options("flip", "resample");
parser.check_incompatible_options("flip-basic", "resample"); parser.check_incompatible_options("flip-basic", "resample");
parser.check_incompatible_options("rotate", "resample"); parser.check_incompatible_options("rotate", "resample");
parser.check_incompatible_options("add-width-height-metadata", "resample");
parser.check_incompatible_options("add", "resample"); parser.check_incompatible_options("add", "resample");
parser.check_incompatible_options("shuffle", "tile"); parser.check_incompatible_options("shuffle", "tile");
parser.check_incompatible_options("sort-num-objects", "tile"); parser.check_incompatible_options("sort-num-objects", "tile");
...@@ -738,6 +777,12 @@ int main(int argc, char** argv) ...@@ -738,6 +777,12 @@ int main(int argc, char** argv)
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
if (parser.option("add-width-height-metadata"))
{
add_width_and_height_metadata(parser);
return EXIT_SUCCESS;
}
if (parser.option("v")) if (parser.option("v"))
{ {
cout << "imglab v" << VERSION cout << "imglab v" << VERSION
......
...@@ -224,7 +224,12 @@ void propagate_boxes( ...@@ -224,7 +224,12 @@ void propagate_boxes(
array2d<rgb_pixel> img1, img2; array2d<rgb_pixel> img1, img2;
dlib::load_image(img1, data.images[prev].filename); dlib::load_image(img1, data.images[prev].filename);
data.images[prev].width = img1.nc();
data.images[prev].height = img1.nr();
dlib::load_image(img2, data.images[next].filename); dlib::load_image(img2, data.images[next].filename);
data.images[next].width = img2.nc();
data.images[next].height = img2.nr();
for (unsigned long i = 0; i < data.images[prev].boxes.size(); ++i) for (unsigned long i = 0; i < data.images[prev].boxes.size(); ++i)
{ {
correlation_tracker tracker; correlation_tracker tracker;
...@@ -513,6 +518,8 @@ load_image( ...@@ -513,6 +518,8 @@ load_image(
try try
{ {
dlib::load_image(img, metadata.images[idx].filename); dlib::load_image(img, metadata.images[idx].filename);
metadata.images[idx].width = img.nc();
metadata.images[idx].height = img.nr();
set_title(metadata.name + " #"+cast_to_string(idx)+": " +metadata.images[idx].filename); set_title(metadata.name + " #"+cast_to_string(idx)+": " +metadata.images[idx].filename);
} }
catch (exception& e) catch (exception& e)
...@@ -543,6 +550,8 @@ load_image_and_set_size( ...@@ -543,6 +550,8 @@ load_image_and_set_size(
try try
{ {
dlib::load_image(img, metadata.images[idx].filename); dlib::load_image(img, metadata.images[idx].filename);
metadata.images[idx].width = img.nc();
metadata.images[idx].height = img.nr();
set_title(metadata.name + " #"+cast_to_string(idx)+": " +metadata.images[idx].filename); set_title(metadata.name + " #"+cast_to_string(idx)+": " +metadata.images[idx].filename);
} }
catch (exception& e) catch (exception& e)
......