Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
1ccd03fe
Unverified
Commit
1ccd03fe
authored
Feb 25, 2022
by
Adrià Arrufat
Committed by
GitHub
Feb 24, 2022
Browse files
Speed up Barlow Twins loss (#2519)
parent
50b78da5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
19 deletions
+32
-19
dlib/dnn/loss.h
dlib/dnn/loss.h
+32
-19
No files found.
dlib/dnn/loss.h
View file @
1ccd03fe
...
@@ -4016,9 +4016,6 @@ namespace dlib
...
@@ -4016,9 +4016,6 @@ namespace dlib
// Normalize both batches independently across the batch dimension
// Normalize both batches independently across the batch dimension
const
double
eps
=
1e-4
;
const
double
eps
=
1e-4
;
resizable_tensor
za_norm
,
means_a
,
invstds_a
;
resizable_tensor
zb_norm
,
means_b
,
invstds_b
;
resizable_tensor
rms
,
rvs
,
g
,
b
;
g
.
set_size
(
1
,
sample_size
);
g
.
set_size
(
1
,
sample_size
);
g
=
1
;
g
=
1
;
b
.
set_size
(
1
,
sample_size
);
b
.
set_size
(
1
,
sample_size
);
...
@@ -4027,21 +4024,29 @@ namespace dlib
...
@@ -4027,21 +4024,29 @@ namespace dlib
tt
::
batch_normalize
(
eps
,
zb_norm
,
means_b
,
invstds_b
,
1
,
rms
,
rvs
,
zb
,
g
,
b
);
tt
::
batch_normalize
(
eps
,
zb_norm
,
means_b
,
invstds_b
,
1
,
rms
,
rvs
,
zb
,
g
,
b
);
// Compute the empirical cross-correlation matrix
// Compute the empirical cross-correlation matrix
resizable_tensor
eccm
;
eccm
.
set_size
(
sample_size
,
sample_size
);
eccm
.
set_size
(
sample_size
,
sample_size
);
tt
::
gemm
(
0
,
eccm
,
1
,
za_norm
,
true
,
zb_norm
,
false
);
tt
::
gemm
(
0
,
eccm
,
1
,
za_norm
,
true
,
zb_norm
,
false
);
eccm
/=
batch_size
;
eccm
/=
batch_size
;
// Compute the loss: MSE between eccm and the identity matrix.
// Set sizes and setup auxiliary tensors
// Off-diagonal terms are weighed by lambda.
if
(
!
have_same_dimensions
(
eccm
,
identity
))
const
matrix
<
float
>
C
=
mat
(
eccm
);
identity
=
identity_matrix
<
float
>
(
sample_size
);
const
double
diagonal_loss
=
sum
(
squared
(
diag
(
C
)
-
1
));
if
(
!
have_same_dimensions
(
eccm
,
cdiag
))
const
double
off_diag_loss
=
sum
(
squared
(
C
-
diagm
(
diag
(
C
))));
cdiag
.
copy_size
(
eccm
);
double
loss
=
diagonal_loss
+
lambda
*
off_diag_loss
;
if
(
!
have_same_dimensions
(
eccm
,
cdiag_1
))
cdiag_1
.
copy_size
(
eccm
);
if
(
!
have_same_dimensions
(
eccm
,
off_mask
))
off_mask
=
ones_matrix
<
float
>
(
sample_size
,
sample_size
)
-
identity_matrix
<
float
>
(
sample_size
);
if
(
!
have_same_dimensions
(
eccm
,
off_diag
))
off_diag
.
copy_size
(
eccm
);
if
(
!
have_same_dimensions
(
grad
,
grad_input
))
grad_input
.
copy_size
(
grad
);
if
(
!
have_same_dimensions
(
g_grad
,
g
))
g_grad
.
copy_size
(
g
);
if
(
!
have_same_dimensions
(
b_grad
,
b
))
b_grad
.
copy_size
(
b
);
// Loss gradient, which will be used as the input of the batch normalization gradient
// Loss gradient, which will be used as the input of the batch normalization gradient
resizable_tensor
grad_input
;
grad_input
.
copy_size
(
grad
);
auto
grad_input_a
=
split
(
grad_input
);
auto
grad_input_a
=
split
(
grad_input
);
auto
grad_input_b
=
split
(
grad_input
,
offset
);
auto
grad_input_b
=
split
(
grad_input
,
offset
);
...
@@ -4051,11 +4056,15 @@ namespace dlib
...
@@ -4051,11 +4056,15 @@ namespace dlib
// C = eccm
// C = eccm
// D = off_mask: a mask that keeps only the elements outside the diagonal
// D = off_mask: a mask that keeps only the elements outside the diagonal
// A diagonal matrix containing the diagonal of eccm
tt
::
multiply
(
false
,
cdiag
,
eccm
,
identity
);
// The diagonal of eccm minus the identity matrix
tt
::
affine_transform
(
cdiag_1
,
cdiag
,
identity
,
1
,
-
1
);
// diagonal term: sum((diag(A' * B) - vector(1)).^2)
// diagonal term: sum((diag(A' * B) - vector(1)).^2)
// --------------------------------------------
// --------------------------------------------
// => d/dA = 2 * B * diag(diag(A' * B) - vector(1)) = 2 * B * diag(diag(C) - vector(1))
// => d/dA = 2 * B * diag(diag(A' * B) - vector(1)) = 2 * B * diag(diag(C) - vector(1))
// => d/dB = 2 * A * diag(diag(A' * B) - vector(1)) = 2 * A * diag(diag(C) - vector(1))
// => d/dB = 2 * A * diag(diag(A' * B) - vector(1)) = 2 * A * diag(diag(C) - vector(1))
resizable_tensor
cdiag_1
(
diagm
(
diag
(
mat
(
eccm
)
-
1
)));
tt
::
gemm
(
0
,
grad_input_a
,
2
,
zb_norm
,
false
,
cdiag_1
,
false
);
tt
::
gemm
(
0
,
grad_input_a
,
2
,
zb_norm
,
false
,
cdiag_1
,
false
);
tt
::
gemm
(
0
,
grad_input_b
,
2
,
za_norm
,
false
,
cdiag_1
,
false
);
tt
::
gemm
(
0
,
grad_input_b
,
2
,
za_norm
,
false
,
cdiag_1
,
false
);
...
@@ -4063,22 +4072,21 @@ namespace dlib
...
@@ -4063,22 +4072,21 @@ namespace dlib
// --------------------------------
// --------------------------------
// => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C .* (D .* D)) = 2 * B * (C .* D)
// => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C .* (D .* D)) = 2 * B * (C .* D)
// => d/dB = 2 * A * ((A' * B) .* (D .* D)) = 2 * A * (C .* (D .* D)) = 2 * A * (C .* D)
// => d/dB = 2 * A * ((A' * B) .* (D .* D)) = 2 * A * (C .* (D .* D)) = 2 * A * (C .* D)
resizable_tensor
off_mask
(
ones_matrix
<
float
>
(
sample_size
,
sample_size
)
-
identity_matrix
<
float
>
(
sample_size
));
resizable_tensor
off_diag
(
sample_size
,
sample_size
);
tt
::
multiply
(
false
,
off_diag
,
eccm
,
off_mask
);
tt
::
multiply
(
false
,
off_diag
,
eccm
,
off_mask
);
tt
::
gemm
(
1
,
grad_input_a
,
2
*
lambda
,
zb_norm
,
false
,
off_diag
,
false
);
tt
::
gemm
(
1
,
grad_input_a
,
2
*
lambda
,
zb_norm
,
false
,
off_diag
,
false
);
tt
::
gemm
(
1
,
grad_input_b
,
2
*
lambda
,
za_norm
,
false
,
off_diag
,
false
);
tt
::
gemm
(
1
,
grad_input_b
,
2
*
lambda
,
za_norm
,
false
,
off_diag
,
false
);
// Compute the batch norm gradients, g and b grads are not used
// Compute the batch norm gradients, g and b grads are not used
resizable_tensor
g_grad
,
b_grad
;
g_grad
.
copy_size
(
g
);
b_grad
.
copy_size
(
b
);
auto
gza
=
split
(
grad
);
auto
gza
=
split
(
grad
);
auto
gzb
=
split
(
grad
,
offset
);
auto
gzb
=
split
(
grad
,
offset
);
tt
::
batch_normalize_gradient
(
eps
,
grad_input_a
,
means_a
,
invstds_a
,
za
,
g
,
gza
,
g_grad
,
b_grad
);
tt
::
batch_normalize_gradient
(
eps
,
grad_input_a
,
means_a
,
invstds_a
,
za
,
g
,
gza
,
g_grad
,
b_grad
);
tt
::
batch_normalize_gradient
(
eps
,
grad_input_b
,
means_b
,
invstds_b
,
zb
,
g
,
gzb
,
g_grad
,
b_grad
);
tt
::
batch_normalize_gradient
(
eps
,
grad_input_b
,
means_b
,
invstds_b
,
zb
,
g
,
gzb
,
g_grad
,
b_grad
);
return
loss
;
// Compute the loss: MSE between eccm and the identity matrix.
// Off-diagonal terms are weighed by lambda.
const
double
diagonal_loss
=
sum
(
squared
(
mat
(
cdiag_1
)));
const
double
off_diag_loss
=
sum
(
squared
(
mat
(
off_diag
)));
return
diagonal_loss
+
lambda
*
off_diag_loss
;
}
}
float
get_lambda
()
const
{
return
lambda
;
}
float
get_lambda
()
const
{
return
lambda
;
}
...
@@ -4116,6 +4124,11 @@ namespace dlib
...
@@ -4116,6 +4124,11 @@ namespace dlib
private:
private:
float
lambda
=
0.0051
;
float
lambda
=
0.0051
;
mutable
resizable_tensor
za_norm
,
means_a
,
invstds_a
;
mutable
resizable_tensor
zb_norm
,
means_b
,
invstds_b
;
mutable
resizable_tensor
rms
,
rvs
,
g
,
b
;
mutable
resizable_tensor
eccm
,
grad_input
,
g_grad
,
b_grad
;
mutable
resizable_tensor
cdiag
,
cdiag_1
,
identity
,
off_mask
,
off_diag
;
};
};
template
<
typename
SUBNET
>
template
<
typename
SUBNET
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment