Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
1ccd03fe
Unverified
Commit
1ccd03fe
authored
Feb 25, 2022
by
Adrià Arrufat
Committed by
GitHub
Feb 24, 2022
Browse files
Speed up Barlow Twins loss (#2519)
parent
50b78da5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
19 deletions
+32
-19
dlib/dnn/loss.h
dlib/dnn/loss.h
+32
-19
No files found.
dlib/dnn/loss.h
View file @
1ccd03fe
...
@@ -4016,9 +4016,6 @@ namespace dlib
...
@@ -4016,9 +4016,6 @@ namespace dlib
// Normalize both batches independently across the batch dimension
// Normalize both batches independently across the batch dimension
const double eps = 1e-4;
const double eps = 1e-4;
resizable_tensor
za_norm
,
means_a
,
invstds_a
;
resizable_tensor
zb_norm
,
means_b
,
invstds_b
;
resizable_tensor
rms
,
rvs
,
g
,
b
;
g.set_size(1, sample_size);
g.set_size(1, sample_size);
g = 1;
g = 1;
b.set_size(1, sample_size);
b.set_size(1, sample_size);
...
@@ -4027,21 +4024,29 @@ namespace dlib
...
@@ -4027,21 +4024,29 @@ namespace dlib
tt::batch_normalize(eps, zb_norm, means_b, invstds_b, 1, rms, rvs, zb, g, b);
tt::batch_normalize(eps, zb_norm, means_b, invstds_b, 1, rms, rvs, zb, g, b);
// Compute the empirical cross-correlation matrix
// Compute the empirical cross-correlation matrix
resizable_tensor
eccm
;
eccm.set_size(sample_size, sample_size);
eccm.set_size(sample_size, sample_size);
tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
eccm /= batch_size;
eccm /= batch_size;
// Compute the loss: MSE between eccm and the identity matrix.
// Set sizes and setup auxiliary tensors
// Off-diagonal terms are weighed by lambda.
if (!have_same_dimensions(eccm, identity))
const
matrix
<
float
>
C
=
mat
(
eccm
);
identity = identity_matrix<float>(sample_size);
const
double
diagonal_loss
=
sum
(
squared
(
diag
(
C
)
-
1
));
if (!have_same_dimensions(eccm, cdiag))
const
double
off_diag_loss
=
sum
(
squared
(
C
-
diagm
(
diag
(
C
))));
cdiag.copy_size(eccm);
double
loss
=
diagonal_loss
+
lambda
*
off_diag_loss
;
if (!have_same_dimensions(eccm, cdiag_1))
cdiag_1.copy_size(eccm);
if (!have_same_dimensions(eccm, off_mask))
off_mask = ones_matrix<float>(sample_size, sample_size) - identity_matrix<float>(sample_size);
if (!have_same_dimensions(eccm, off_diag))
off_diag.copy_size(eccm);
if (!have_same_dimensions(grad, grad_input))
grad_input.copy_size(grad);
if (!have_same_dimensions(g_grad, g))
g_grad.copy_size(g);
if (!have_same_dimensions(b_grad, b))
b_grad.copy_size(b);
// Loss gradient, which will be used as the input of the batch normalization gradient
// Loss gradient, which will be used as the input of the batch normalization gradient
resizable_tensor
grad_input
;
grad_input
.
copy_size
(
grad
);
auto grad_input_a = split(grad_input);
auto grad_input_a = split(grad_input);
auto grad_input_b = split(grad_input, offset);
auto grad_input_b = split(grad_input, offset);
...
@@ -4051,11 +4056,15 @@ namespace dlib
...
@@ -4051,11 +4056,15 @@ namespace dlib
// C = eccm
// C = eccm
// D = off_mask: a mask that keeps only the elements outside the diagonal
// D = off_mask: a mask that keeps only the elements outside the diagonal
// A diagonal matrix containing the diagonal of eccm
tt::multiply(false, cdiag, eccm, identity);
// The diagonal of eccm minus the identity matrix
tt::affine_transform(cdiag_1, cdiag, identity, 1, -1);
// diagonal term: sum((diag(A' * B) - vector(1)).^2)
// diagonal term: sum((diag(A' * B) - vector(1)).^2)
// --------------------------------------------
// --------------------------------------------
// => d/dA = 2 * B * diag(diag(A' * B) - vector(1)) = 2 * B * diag(diag(C) - vector(1))
// => d/dA = 2 * B * diag(diag(A' * B) - vector(1)) = 2 * B * diag(diag(C) - vector(1))
// => d/dB = 2 * A * diag(diag(A' * B) - vector(1)) = 2 * A * diag(diag(C) - vector(1))
// => d/dB = 2 * A * diag(diag(A' * B) - vector(1)) = 2 * A * diag(diag(C) - vector(1))
resizable_tensor
cdiag_1
(
diagm
(
diag
(
mat
(
eccm
)
-
1
)));
tt::gemm(0, grad_input_a, 2, zb_norm, false, cdiag_1, false);
tt::gemm(0, grad_input_a, 2, zb_norm, false, cdiag_1, false);
tt::gemm(0, grad_input_b, 2, za_norm, false, cdiag_1, false);
tt::gemm(0, grad_input_b, 2, za_norm, false, cdiag_1, false);
...
@@ -4063,22 +4072,21 @@ namespace dlib
...
@@ -4063,22 +4072,21 @@ namespace dlib
// --------------------------------
// --------------------------------
// => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C .* (D .* D)) = 2 * B * (C .* D)
// => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C .* (D .* D)) = 2 * B * (C .* D)
// => d/dB = 2 * A * ((A' * B) .* (D .* D)) = 2 * A * (C .* (D .* D)) = 2 * A * (C .* D)
// => d/dB = 2 * A * ((A' * B) .* (D .* D)) = 2 * A * (C .* (D .* D)) = 2 * A * (C .* D)
resizable_tensor
off_mask
(
ones_matrix
<
float
>
(
sample_size
,
sample_size
)
-
identity_matrix
<
float
>
(
sample_size
));
resizable_tensor
off_diag
(
sample_size
,
sample_size
);
tt::multiply(false, off_diag, eccm, off_mask);
tt::multiply(false, off_diag, eccm, off_mask);
tt::gemm(1, grad_input_a, 2 * lambda, zb_norm, false, off_diag, false);
tt::gemm(1, grad_input_a, 2 * lambda, zb_norm, false, off_diag, false);
tt::gemm(1, grad_input_b, 2 * lambda, za_norm, false, off_diag, false);
tt::gemm(1, grad_input_b, 2 * lambda, za_norm, false, off_diag, false);
// Compute the batch norm gradients, g and b grads are not used
// Compute the batch norm gradients, g and b grads are not used
resizable_tensor
g_grad
,
b_grad
;
g_grad
.
copy_size
(
g
);
b_grad
.
copy_size
(
b
);
auto gza = split(grad);
auto gza = split(grad);
auto gzb = split(grad, offset);
auto gzb = split(grad, offset);
tt::batch_normalize_gradient(eps, grad_input_a, means_a, invstds_a, za, g, gza, g_grad, b_grad);
tt::batch_normalize_gradient(eps, grad_input_a, means_a, invstds_a, za, g, gza, g_grad, b_grad);
tt::batch_normalize_gradient(eps, grad_input_b, means_b, invstds_b, zb, g, gzb, g_grad, b_grad);
tt::batch_normalize_gradient(eps, grad_input_b, means_b, invstds_b, zb, g, gzb, g_grad, b_grad);
return
loss
;
// Compute the loss: MSE between eccm and the identity matrix.
// Off-diagonal terms are weighed by lambda.
const double diagonal_loss = sum(squared(mat(cdiag_1)));
const double off_diag_loss = sum(squared(mat(off_diag)));
return diagonal_loss + lambda * off_diag_loss;
}
}
// Returns lambda, the weight applied to the off-diagonal terms of the
// empirical cross-correlation matrix in the Barlow Twins loss
// (defaults to 0.0051; see the lambda member below).
float get_lambda() const { return lambda; }
float get_lambda() const { return lambda; }
...
@@ -4116,6 +4124,11 @@ namespace dlib
...
@@ -4116,6 +4124,11 @@ namespace dlib
private:
private:
// Weighting factor for the off-diagonal terms of the empirical
// cross-correlation matrix when computing the Barlow Twins loss.
float lambda = 0.0051;
float lambda = 0.0051;
// Scratch tensors cached across loss evaluations so each batch does not
// reallocate them (this is the speed-up introduced by this commit: the
// loss body resizes them only when dimensions change, via the
// have_same_dimensions checks). They are declared mutable, presumably so
// they can be updated from a const loss-computation method — TODO confirm
// against the enclosing function's signature, which is outside this view.
mutable resizable_tensor za_norm, means_a, invstds_a;
mutable resizable_tensor zb_norm, means_b, invstds_b;
mutable resizable_tensor rms, rvs, g, b;
mutable resizable_tensor eccm, grad_input, g_grad, b_grad;
mutable resizable_tensor cdiag, cdiag_1, identity, off_mask, off_diag;
};
};
template <typename SUBNET>
template <typename SUBNET>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment