Commit b590f8c6 authored by Will Brennan's avatar Will Brennan Committed by Francisco Massa
Browse files

Fix broken bitwise operation in Similarity Reference loss (#1604)

parent b8e3e969
......@@ -77,7 +77,7 @@ def batch_all_triplet_loss(labels, embeddings, margin, p):
def _get_triplet_mask(labels):
# Check that i, j and k are distinct
indices_equal = torch.eye(labels.size(0), dtype=torch.uint8, device=labels.device)
indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
indices_not_equal = ~indices_equal
i_not_equal_j = indices_not_equal.unsqueeze(2)
i_not_equal_k = indices_not_equal.unsqueeze(1)
......@@ -96,7 +96,7 @@ def _get_triplet_mask(labels):
def _get_anchor_positive_triplet_mask(labels):
# Check that i and j are distinct
indices_equal = torch.eye(labels.size(0), dtype=torch.uint8, device=labels.device)
indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
indices_not_equal = ~indices_equal
# Check if labels[i] == labels[j]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment