OpenDAS / dlib · Commits · 1ccd03fe

Commit 1ccd03fe (Unverified)
Speed up Barlow Twins loss (#2519)

Authored Feb 25, 2022 by Adrià Arrufat; committed by GitHub on Feb 24, 2022
Parent: 50b78da5

Showing 1 changed file with 32 additions and 19 deletions.

dlib/dnn/loss.h (+32, -19), view file @ 1ccd03fe
@@ -4016,9 +4016,6 @@ namespace dlib
             // Normalize both batches independently across the batch dimension
             const double eps = 1e-4;
-            resizable_tensor za_norm, means_a, invstds_a;
-            resizable_tensor zb_norm, means_b, invstds_b;
-            resizable_tensor rms, rvs, g, b;
             g.set_size(1, sample_size);
             g = 1;
             b.set_size(1, sample_size);
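Note (not part of this commit): in this hunk, tt::batch_normalize standardizes each of the sample_size embedding dimensions across the batch, with the scale g fixed to 1 (and the shift b presumably zeroed in the surrounding, collapsed lines). Each feature j of an embedding z becomes

    \hat{z}_{ij} = \frac{z_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \varepsilon}}, \qquad \varepsilon = 10^{-4},

where \mu_j and \sigma_j^2 are the per-feature mean and variance over the batch. This normalization is what makes the matrix built in the next hunk an empirical cross-correlation matrix with entries roughly in [-1, 1].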
@@ -4027,21 +4024,29 @@ namespace dlib
             tt::batch_normalize(eps, zb_norm, means_b, invstds_b, 1, rms, rvs, zb, g, b);

             // Compute the empirical cross-correlation matrix
-            resizable_tensor eccm;
             eccm.set_size(sample_size, sample_size);
             tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
             eccm /= batch_size;

-            // Compute the loss: MSE between eccm and the identity matrix.
-            // Off-diagonal terms are weighed by lambda.
-            const matrix<float> C = mat(eccm);
-            const double diagonal_loss = sum(squared(diag(C) - 1));
-            const double off_diag_loss = sum(squared(C - diagm(diag(C))));
-            double loss = diagonal_loss + lambda * off_diag_loss;
+            // Set sizes and setup auxiliary tensors
+            if (!have_same_dimensions(eccm, identity))
+                identity = identity_matrix<float>(sample_size);
+            if (!have_same_dimensions(eccm, cdiag))
+                cdiag.copy_size(eccm);
+            if (!have_same_dimensions(eccm, cdiag_1))
+                cdiag_1.copy_size(eccm);
+            if (!have_same_dimensions(eccm, off_mask))
+                off_mask = ones_matrix<float>(sample_size, sample_size) - identity_matrix<float>(sample_size);
+            if (!have_same_dimensions(eccm, off_diag))
+                off_diag.copy_size(eccm);
+            if (!have_same_dimensions(grad, grad_input))
+                grad_input.copy_size(grad);
+            if (!have_same_dimensions(g_grad, g))
+                g_grad.copy_size(g);
+            if (!have_same_dimensions(b_grad, b))
+                b_grad.copy_size(b);

             // Loss gradient, which will be used as the input of the batch normalization gradient
-            resizable_tensor grad_input;
-            grad_input.copy_size(grad);
             auto grad_input_a = split(grad_input);
             auto grad_input_b = split(grad_input, offset);
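Note (not part of this commit): with C = eccm = \hat{Z}_a^{\top} \hat{Z}_b / \text{batch\_size} as computed by the gemm above, the quantity described by the removed comment ("MSE between eccm and the identity matrix, off-diagonal terms weighed by lambda") is the Barlow Twins objective

    \mathcal{L} \;=\; \sum_{i} \left(C_{ii} - 1\right)^2 \;+\; \lambda \sum_{i}\sum_{j \neq i} C_{ij}^2 .

The commit does not change this value; it only moves its evaluation to the end of the function, where it can be read off tensors that the gradient computation produces anyway.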
@@ -4051,11 +4056,15 @@ namespace dlib
             // C = eccm
             // D = off_mask: a mask that keeps only the elements outside the diagonal

+            // A diagonal matrix containing the diagonal of eccm
+            tt::multiply(false, cdiag, eccm, identity);
+            // The diagonal of eccm minus the identity matrix
+            tt::affine_transform(cdiag_1, cdiag, identity, 1, -1);
+
             // diagonal term: sum((diag(A' * B) - vector(1)).^2)
             // --------------------------------------------
             // => d/dA = 2 * B * diag(diag(A' * B) - vector(1)) = 2 * B * diag(diag(C) - vector(1))
             // => d/dB = 2 * A * diag(diag(A' * B) - vector(1)) = 2 * A * diag(diag(C) - vector(1))
-            resizable_tensor cdiag_1(diagm(diag(mat(eccm) - 1)));
             tt::gemm(0, grad_input_a, 2, zb_norm, false, cdiag_1, false);
             tt::gemm(0, grad_input_b, 2, za_norm, false, cdiag_1, false);
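Note (not part of this commit): the two added tensor operations rebuild, with tensor primitives, the matrix that the removed temporary used to construct through dlib::matrix expressions. With I the identity and \circ the element-wise product,

    \texttt{cdiag} = C \circ I, \qquad \texttt{cdiag\_1} = C \circ I - I = \operatorname{diagm}\!\big(\operatorname{diag}(C) - \mathbf{1}\big),

i.e. the same diagonal matrix as the old diagm(diag(mat(eccm) - 1)), but presumably without the device-to-host round trip that going through mat(eccm) implies when the tensors live on the GPU.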
@@ -4063,22 +4072,21 @@ namespace dlib
             // --------------------------------
             // => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C .* (D .* D)) = 2 * B * (C .* D)
             // => d/dB = 2 * A * ((A' * B) .* (D .* D)) = 2 * A * (C .* (D .* D)) = 2 * A * (C .* D)
-            resizable_tensor off_mask(ones_matrix<float>(sample_size, sample_size) - identity_matrix<float>(sample_size));
-            resizable_tensor off_diag(sample_size, sample_size);
             tt::multiply(false, off_diag, eccm, off_mask);
             tt::gemm(1, grad_input_a, 2 * lambda, zb_norm, false, off_diag, false);
             tt::gemm(1, grad_input_b, 2 * lambda, za_norm, false, off_diag, false);

             // Compute the batch norm gradients, g and b grads are not used
-            resizable_tensor g_grad, b_grad;
-            g_grad.copy_size(g);
-            b_grad.copy_size(b);
             auto gza = split(grad);
             auto gzb = split(grad, offset);
             tt::batch_normalize_gradient(eps, grad_input_a, means_a, invstds_a, za, g, gza, g_grad, b_grad);
             tt::batch_normalize_gradient(eps, grad_input_b, means_b, invstds_b, zb, g, gzb, g_grad, b_grad);
-            return loss;
+
+            // Compute the loss: MSE between eccm and the identity matrix.
+            // Off-diagonal terms are weighed by lambda.
+            const double diagonal_loss = sum(squared(mat(cdiag_1)));
+            const double off_diag_loss = sum(squared(mat(off_diag)));
+            return diagonal_loss + lambda * off_diag_loss;
         }

         float get_lambda() const { return lambda; }
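Note (not part of this commit): the refactored return statement reuses cdiag_1 and off_diag, which the gradient step above already filled in, instead of rebuilding the loss from mat(eccm) and fresh matrix temporaries. A minimal standalone sketch, using only plain dlib matrix operations and a hypothetical toy matrix C standing in for mat(eccm), that checks the two formulations agree:

    #include <dlib/matrix.h>
    #include <cassert>
    #include <cmath>

    int main()
    {
        using namespace dlib;
        const long n = 8;                // stand-in for sample_size
        const double lambda = 0.0051;    // default weighting, as in the class below
        // Hypothetical toy cross-correlation matrix standing in for mat(eccm).
        const matrix<float> C = matrix_cast<float>(gaussian_randm(n, n, 0));

        // Formulation removed by this commit: temporaries built from C directly.
        const double old_diag = sum(squared(diag(C) - 1));
        const double old_off  = sum(squared(C - diagm(diag(C))));

        // Formulation added by this commit: reuse the gradient intermediates
        // cdiag_1 = C .* I - I and off_diag = C .* (1 - I).
        const matrix<float> I        = identity_matrix<float>(n);
        const matrix<float> cdiag_1  = pointwise_multiply(C, I) - I;
        const matrix<float> off_diag = pointwise_multiply(C, ones_matrix<float>(n, n) - I);
        const double new_diag = sum(squared(cdiag_1));
        const double new_off  = sum(squared(off_diag));

        // Both ways yield the same Barlow Twins loss value (up to float rounding).
        assert(std::abs((old_diag + lambda * old_off) - (new_diag + lambda * new_off)) < 1e-3);
        return 0;
    }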
@@ -4116,6 +4124,11 @@
     private:
         float lambda = 0.0051;

+        mutable resizable_tensor za_norm, means_a, invstds_a;
+        mutable resizable_tensor zb_norm, means_b, invstds_b;
+        mutable resizable_tensor rms, rvs, g, b;
+        mutable resizable_tensor eccm, grad_input, g_grad, b_grad;
+        mutable resizable_tensor cdiag, cdiag_1, identity, off_mask, off_diag;
     };

     template <typename SUBNET>