Unverified Commit 1b7053fe authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Add focal gamma to loss_multibinary_log (#2546)

* Add focal gamma to loss_multibinary_log

* update release notes
parent f1a29f35
......@@ -770,6 +770,13 @@ namespace dlib
typedef std::vector<float> training_label_type;
typedef std::vector<float> output_label_type;
loss_multibinary_log_() = default;
loss_multibinary_log_(double gamma) : gamma(gamma)
{
DLIB_CASSERT(gamma >= 0);
}
template <
typename SUB_TYPE,
typename label_iterator
......@@ -842,43 +849,53 @@ namespace dlib
if (y > 0)
{
const float temp = log1pexp(-out_data[idx]);
const float focus = std::pow(1 - g[idx], gamma);
loss += y * scale * temp;
g[idx] = y * scale * (g[idx] - 1);
g[idx] = y * scale * focus * (g[idx] * (gamma * temp + 1) - 1);
}
else
{
const float temp = -(-out_data[idx] - log1pexp(-out_data[idx]));
const float focus = std::pow(g[idx], gamma);
loss += -y * scale * temp;
g[idx] = -y * scale * g[idx];
g[idx] = -y * scale * focus * g[idx] * (gamma * temp + 1);
}
}
}
return loss;
}
friend void serialize(const loss_multibinary_log_&, std::ostream& out)
double get_gamma () const { return gamma; }
friend void serialize(const loss_multibinary_log_& item, std::ostream& out)
{
serialize("loss_multibinary_log_", out);
serialize("loss_multibinary_log_2", out);
serialize(item.gamma, out);
}
friend void deserialize(loss_multibinary_log_&, std::istream& in)
friend void deserialize(loss_multibinary_log_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "loss_multibinary_log_")
if (version != "loss_multibinary_log_" || version != "loss_multibinary_log_2")
throw serialization_error("Unexpected version found while deserializing dlib::loss_multibinary_log_.");
if (version == "loss_multibinary_log_2")
deserialize(item.gamma, in);
}
friend std::ostream& operator<<(std::ostream& out, const loss_multibinary_log_& )
friend std::ostream& operator<<(std::ostream& out, const loss_multibinary_log_& item)
{
out << "loss_multibinary_log";
out << "loss_multibinary_log (gamma=" << item.gamma << ")";
return out;
}
friend void to_xml(const loss_multibinary_log_& /*item*/, std::ostream& out)
friend void to_xml(const loss_multibinary_log_& item, std::ostream& out)
{
out << "<loss_multibinary_log/>";
out << "<loss_multibinary_log gamma='" << item.gamma << "'/>";
}
private:
double gamma = 0;
};
template <typename SUBNET>
......
......@@ -718,6 +718,15 @@ namespace dlib
To be more specific, this object contains a sigmoid layer followed by a
cross-entropy layer.
Additionaly, this layer also contains a focusing parameter gamma, which
acts as a modulating factor to the cross-entropy layer by reducing the
relative loss for well-classified examples, and focusing on the difficult
ones. This gamma parameter makes this layer behave like the Focal loss,
presented in the paper:
Focal Loss for Dense Object Detection
by Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár
(https://arxiv.org/abs/1708.02002)
An example will make its use clear. So suppose, for example, that you want
to make a classifier for cats and dogs, but what happens if they both
appear in one image? Or none of them? This layer allows you to handle
......@@ -727,10 +736,32 @@ namespace dlib
- std::vector<float> both_label = {1.f, 1.f};
- std::vector<float> none_label = {-1.f, -1.f};
!*/
public:
typedef std::vector<float> training_label_type;
typedef std::vector<float> output_label_type;
loss_multibinary_log_ (
);
/*!
ensures
- #get_gamma() == 0
!*/
loss_multibinary_log_(double gamma);
/*!
requires
- gamma >= 0
ensures
- #get_gamma() == gamma
!*/
double get_gamma() const;
/*!
ensures
- returns the gamma value used by the loss function.
!*/
template <
typename SUB_TYPE,
typename label_iterator
......
......@@ -17,6 +17,7 @@ New Features and Improvements:
- Added ReOrg layer.
- Added visitor to draw network architectures using the DOT language.
- Made Barlow Twins loss much faster for high dimensionality inputs.
- Added Focal loss gamma to loss_multibinary_log_.
Non-Backwards Compatible Changes:
- Do not round coordinates in rectangle_transform (PR #2498).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment