Unverified commit 1b7053fe, authored by Adrià Arrufat, committed by GitHub

Add focal gamma to loss_multibinary_log (#2546)

* Add focal gamma to loss_multibinary_log

* update release notes
parent f1a29f35
@@ -770,6 +770,13 @@ namespace dlib
         typedef std::vector<float> training_label_type;
         typedef std::vector<float> output_label_type;
 
+        loss_multibinary_log_() = default;
+
+        loss_multibinary_log_(double gamma) : gamma(gamma)
+        {
+            DLIB_CASSERT(gamma >= 0);
+        }
+
         template <
             typename SUB_TYPE,
             typename label_iterator
@@ -842,43 +849,53 @@ namespace dlib
                         if (y > 0)
                         {
                             const float temp = log1pexp(-out_data[idx]);
+                            const float focus = std::pow(1 - g[idx], gamma);
                             loss += y * scale * temp;
-                            g[idx] = y * scale * (g[idx] - 1);
+                            g[idx] = y * scale * focus * (g[idx] * (gamma * temp + 1) - 1);
                         }
                         else
                         {
                             const float temp = -(-out_data[idx] - log1pexp(-out_data[idx]));
+                            const float focus = std::pow(g[idx], gamma);
                             loss += -y * scale * temp;
-                            g[idx] = -y * scale * g[idx];
+                            g[idx] = -y * scale * focus * g[idx] * (gamma * temp + 1);
                         }
                     }
                 }
                 return loss;
             }
 
+            double get_gamma () const { return gamma; }
+
-            friend void serialize(const loss_multibinary_log_&, std::ostream& out)
+            friend void serialize(const loss_multibinary_log_& item, std::ostream& out)
             {
-                serialize("loss_multibinary_log_", out);
+                serialize("loss_multibinary_log_2", out);
+                serialize(item.gamma, out);
             }
 
-            friend void deserialize(loss_multibinary_log_&, std::istream& in)
+            friend void deserialize(loss_multibinary_log_& item, std::istream& in)
             {
                 std::string version;
                 deserialize(version, in);
-                if (version != "loss_multibinary_log_")
+                if (version != "loss_multibinary_log_" && version != "loss_multibinary_log_2")
                     throw serialization_error("Unexpected version found while deserializing dlib::loss_multibinary_log_.");
+                if (version == "loss_multibinary_log_2")
+                    deserialize(item.gamma, in);
             }
 
-            friend std::ostream& operator<<(std::ostream& out, const loss_multibinary_log_& )
+            friend std::ostream& operator<<(std::ostream& out, const loss_multibinary_log_& item)
             {
-                out << "loss_multibinary_log";
+                out << "loss_multibinary_log (gamma=" << item.gamma << ")";
                 return out;
             }
 
-            friend void to_xml(const loss_multibinary_log_& /*item*/, std::ostream& out)
+            friend void to_xml(const loss_multibinary_log_& item, std::ostream& out)
             {
-                out << "<loss_multibinary_log/>";
+                out << "<loss_multibinary_log gamma='" << item.gamma << "'/>";
             }
 
+        private:
+            double gamma = 0;
+
         };
 
     template <typename SUBNET>
...
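For reference (an aside, not part of the diff): in the gradient code above, g[idx] initially holds the sigmoid output p of the corresponding logit, temp equals the cross-entropy term for the correct label in both branches, and the new focus variable is the modulating factor from the Focal loss paper. In the paper's notation:

    p_t = \begin{cases} p & \text{if } y > 0 \\ 1 - p & \text{otherwise} \end{cases},
    \qquad \text{temp} = -\log p_t,
    \qquad \text{focus} = (1 - p_t)^{\gamma},
    \qquad \mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log p_t.

With gamma == 0 this gives focus == 1 and gamma * temp + 1 == 1, so both new g[idx] expressions reduce exactly to the previous ones and the default behaviour stays plain sigmoid cross-entropy.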
@@ -718,6 +718,15 @@ namespace dlib
                 To be more specific, this object contains a sigmoid layer followed by a
                 cross-entropy layer.
+                Additionally, this layer also contains a focusing parameter gamma, which
+                acts as a modulating factor to the cross-entropy layer by reducing the
+                relative loss for well-classified examples and focusing on the difficult
+                ones.  This gamma parameter makes this layer behave like the Focal loss
+                presented in the paper:
+                    Focal Loss for Dense Object Detection
+                    by Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár
+                    (https://arxiv.org/abs/1708.02002)
+
                 An example will make its use clear.  So suppose, for example, that you want
                 to make a classifier for cats and dogs, but what happens if they both
                 appear in one image? Or none of them? This layer allows you to handle
@@ -727,10 +736,32 @@ namespace dlib
                     - std::vector<float> both_label = {1.f, 1.f};
                     - std::vector<float> none_label = {-1.f, -1.f};
        !*/
     public:
 
         typedef std::vector<float> training_label_type;
         typedef std::vector<float> output_label_type;
 
+        loss_multibinary_log_ (
+        );
+        /*!
+            ensures
+                - #get_gamma() == 0
+        !*/
+
+        loss_multibinary_log_(double gamma);
+        /*!
+            requires
+                - gamma >= 0
+            ensures
+                - #get_gamma() == gamma
+        !*/
+
+        double get_gamma() const;
+        /*!
+            ensures
+                - returns the gamma value used by the loss function.
+        !*/
+
         template <
             typename SUB_TYPE,
             typename label_iterator
...
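For reference, a minimal usage sketch of the new interface (the architecture, layer sizes, and input type below are illustrative assumptions, not taken from this commit):

    #include <dlib/dnn.h>
    #include <iostream>

    using namespace dlib;

    // Hypothetical two-output tagger, e.g. independent "cat" and "dog" flags.
    using net_type = loss_multibinary_log<fc<2, relu<fc<32, input<matrix<float>>>>>>;

    int main()
    {
        // Default-constructed loss: gamma == 0, i.e. the previous behaviour
        // of plain sigmoid + binary cross-entropy.
        net_type net;

        // Constructing the network from loss details with gamma == 2 enables
        // the focal modulation added by this commit.
        net_type focal_net(loss_multibinary_log_(2.0));
        std::cout << focal_net.loss_details().get_gamma() << '\n';  // prints 2
    }

Networks serialized before this change still deserialize correctly, since both the "loss_multibinary_log_" and "loss_multibinary_log_2" version strings are accepted.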
@@ -17,6 +17,7 @@ New Features and Improvements:
    - Added ReOrg layer.
    - Added visitor to draw network architectures using the DOT language.
    - Made Barlow Twins loss much faster for high dimensionality inputs.
+   - Added Focal loss gamma to loss_multibinary_log_.
 
 Non-Backwards Compatible Changes:
    - Do not round coordinates in rectangle_transform (PR #2498).
...