OpenDAS/dlib commit 1b7053fe (unverified)
Authored Mar 22, 2022 by Adrià Arrufat; committed by GitHub Mar 21, 2022
Parent: f1a29f35

Add focal gamma to loss_multibinary_log (#2546)

* Add focal gamma to loss_multibinary_log
* update release notes
Showing 3 changed files with 59 additions and 10 deletions:

    dlib/dnn/loss.h              +27  -10
    dlib/dnn/loss_abstract.h     +31  -0
    docs/docs/release_notes.xml   +1  -0
dlib/dnn/loss.h

@@ -770,6 +770,13 @@ namespace dlib
         typedef std::vector<float> training_label_type;
         typedef std::vector<float> output_label_type;
 
+        loss_multibinary_log_() = default;
+
+        loss_multibinary_log_(double gamma) : gamma(gamma)
+        {
+            DLIB_CASSERT(gamma >= 0);
+        }
+
         template <
             typename SUB_TYPE,
             typename label_iterator
@@ -842,43 +849,53 @@ namespace dlib
...
@@ -842,43 +849,53 @@ namespace dlib
if
(
y
>
0
)
if
(
y
>
0
)
{
{
const
float
temp
=
log1pexp
(
-
out_data
[
idx
]);
const
float
temp
=
log1pexp
(
-
out_data
[
idx
]);
const
float
focus
=
std
::
pow
(
1
-
g
[
idx
],
gamma
);
loss
+=
y
*
scale
*
temp
;
loss
+=
y
*
scale
*
temp
;
g
[
idx
]
=
y
*
scale
*
(
g
[
idx
]
-
1
);
g
[
idx
]
=
y
*
scale
*
focus
*
(
g
[
idx
]
*
(
gamma
*
temp
+
1
)
-
1
);
}
}
else
else
{
{
const
float
temp
=
-
(
-
out_data
[
idx
]
-
log1pexp
(
-
out_data
[
idx
]));
const
float
temp
=
-
(
-
out_data
[
idx
]
-
log1pexp
(
-
out_data
[
idx
]));
const
float
focus
=
std
::
pow
(
g
[
idx
],
gamma
);
loss
+=
-
y
*
scale
*
temp
;
loss
+=
-
y
*
scale
*
temp
;
g
[
idx
]
=
-
y
*
scale
*
g
[
idx
]
;
g
[
idx
]
=
-
y
*
scale
*
focus
*
g
[
idx
]
*
(
gamma
*
temp
+
1
)
;
}
}
}
}
}
}
return
loss
;
return
loss
;
}
}
friend
void
serialize
(
const
loss_multibinary_log_
&
,
std
::
ostream
&
out
)
double
get_gamma
()
const
{
return
gamma
;
}
friend
void
serialize
(
const
loss_multibinary_log_
&
item
,
std
::
ostream
&
out
)
{
{
serialize
(
"loss_multibinary_log_"
,
out
);
serialize
(
"loss_multibinary_log_2"
,
out
);
serialize
(
item
.
gamma
,
out
);
}
}
friend
void
deserialize
(
loss_multibinary_log_
&
,
std
::
istream
&
in
)
friend
void
deserialize
(
loss_multibinary_log_
&
item
,
std
::
istream
&
in
)
{
{
std
::
string
version
;
std
::
string
version
;
deserialize
(
version
,
in
);
deserialize
(
version
,
in
);
if
(
version
!=
"loss_multibinary_log_"
)
if
(
version
!=
"loss_multibinary_log_"
||
version
!=
"loss_multibinary_log_2"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::loss_multibinary_log_."
);
throw
serialization_error
(
"Unexpected version found while deserializing dlib::loss_multibinary_log_."
);
if
(
version
==
"loss_multibinary_log_2"
)
deserialize
(
item
.
gamma
,
in
);
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
loss_multibinary_log_
&
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
loss_multibinary_log_
&
item
)
{
{
out
<<
"loss_multibinary_log"
;
out
<<
"loss_multibinary_log
(gamma="
<<
item
.
gamma
<<
")
"
;
return
out
;
return
out
;
}
}
friend
void
to_xml
(
const
loss_multibinary_log_
&
/*
item
*/
,
std
::
ostream
&
out
)
friend
void
to_xml
(
const
loss_multibinary_log_
&
item
,
std
::
ostream
&
out
)
{
{
out
<<
"<loss_multibinary_log/>"
;
out
<<
"<loss_multibinary_log
gamma='"
<<
item
.
gamma
<<
"'
/>"
;
}
}
private:
double
gamma
=
0
;
};
};
template
<
typename
SUBNET
>
template
<
typename
SUBNET
>
...
...
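As background (this note is not part of the commit), the gamma parameter corresponds to the focal loss from the paper cited in the documentation change below. In the hunk above, temp is the usual cross-entropy term -log(p_t) for the true label, and focus is the modulating factor (1 - p_t)^gamma:

    % Focal loss as defined in Lin et al., "Focal Loss for Dense Object
    % Detection" (https://arxiv.org/abs/1708.02002); p_t is the predicted
    % probability of the true label and gamma >= 0 is the focusing parameter.
    \[
        \mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t)
    \]
    % With gamma = 0 the factor (1 - p_t)^gamma equals 1, so the loss reduces
    % to plain cross-entropy, matching this layer's pre-commit behavior.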
dlib/dnn/loss_abstract.h
@@ -718,6 +718,15 @@ namespace dlib
                 To be more specific, this object contains a sigmoid layer followed by a
                 cross-entropy layer.
 
+                Additionally, this layer also contains a focusing parameter gamma, which
+                acts as a modulating factor on the cross-entropy layer by reducing the
+                relative loss for well-classified examples and focusing on the difficult
+                ones.  This gamma parameter makes this layer behave like the Focal loss
+                presented in the paper:
+                    Focal Loss for Dense Object Detection
+                    by Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár
+                    (https://arxiv.org/abs/1708.02002)
+
                 An example will make its use clear.  So suppose, for example, that you want
                 to make a classifier for cats and dogs, but what happens if they both
                 appear in one image? Or none of them? This layer allows you to handle
@@ -727,10 +736,32 @@ namespace dlib
                 - std::vector<float> both_label = {1.f, 1.f};
                 - std::vector<float> none_label = {-1.f, -1.f};
         !*/
 
     public:
 
         typedef std::vector<float> training_label_type;
         typedef std::vector<float> output_label_type;
 
+        loss_multibinary_log_();
+        /*!
+            ensures
+                - #get_gamma() == 0
+        !*/
+
+        loss_multibinary_log_(double gamma);
+        /*!
+            requires
+                - gamma >= 0
+            ensures
+                - #get_gamma() == gamma
+        !*/
+
+        double get_gamma() const;
+        /*!
+            ensures
+                - returns the gamma value used by the loss function.
+        !*/
+
         template <
             typename SUB_TYPE,
             typename label_iterator
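A minimal usage sketch of the documented interface (not part of the commit; the toy architecture and the gamma value of 2.0 are illustrative, and it assumes dlib's add_loss_layer constructor that accepts a loss-details instance):

    // Hedged sketch: two independent binary outputs (e.g. "cat present" and
    // "dog present"), with the loss built from the new gamma overload.
    #include <dlib/dnn.h>
    #include <iostream>

    using namespace dlib;

    // Toy architecture, chosen only for illustration.
    using net_type = loss_multibinary_log<fc<2, relu<fc<16, input<matrix<float>>>>>>;

    int main()
    {
        // gamma == 0 (the default) keeps the previous plain cross-entropy
        // behavior; larger values down-weight well-classified examples.
        net_type net(loss_multibinary_log_(2.0));
        std::cout << "gamma: " << net.loss_details().get_gamma() << std::endl;
        return 0;
    }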
docs/docs/release_notes.xml
@@ -17,6 +17,7 @@ New Features and Improvements:
  - Added ReOrg layer.
  - Added visitor to draw network architectures using the DOT language.
  - Made Barlow Twins loss much faster for high dimensionality inputs.
+ - Added Focal loss gamma to loss_multibinary_log_.
 
 Non-Backwards Compatible Changes:
  - Do not round coordinates in rectangle_transform (PR #2498).