Commit bbd363ca authored by Daksh Jotwani, committed by Francisco Massa

Similarity learning reference code (#1101)

* Add loss, sampler, and train script

* Fix train script

* Add argparse

* Fix lint

* Change f strings to .format()

* Remove unused imports

* Change TripletMarginLoss to extend nn.Module

* Load eye uint8 tensors directly on device

* Refactor model.py to backbone=None

* Add docstring for PKSampler

* Refactor evaluate() to take loader as arg instead

* Change eval method to cat embeddings all at once

* Add dataset comments

* Add README.md

* Add tests for sampler

* Refactor threshold finder to helper method

* Refactor targets comment

* Fix lint

* Rename embedding to similarity (More consistent with existing literature)
parent 8837e0ef
# Similarity Learning Using Triplet Loss #
In this reference, we use triplet loss to learn embeddings which can be used to differentiate images. This learning technique was popularized by [FaceNet: A Unified Embedding for Face Recognition and Clustering](https://arxiv.org/abs/1503.03832) and has been quite effective in learning embeddings to differentiate between faces.
This reference can be directly applied to the following use cases:
* You have an unknown number of classes and would like to train a model to learn how to differentiate between them.
* You want to train a model to learn a distance-based metric between samples. For example, learning a distance-based similarity measure between faces.
### Training ###
By default, the training script trains ResNet50 on the FashionMNIST dataset to learn image embeddings which can be used to differentiate images by measuring the Euclidean distance between their embeddings. This can be changed to suit your requirements.
Image embeddings of the same class should be 'close' to each other, while image embeddings between different classes should be 'far' away.
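For instance, here is a minimal sketch (using random stand-in embeddings; the real ones come from the model's L2-normalized 128-dimensional output) of how a pair of embeddings can be compared:
```python
import torch
import torch.nn.functional as F

# Stand-in embeddings for two images; EmbeddingNet produces
# L2-normalized 128-dimensional vectors like these.
emb_a = F.normalize(torch.randn(1, 128), dim=1)
emb_b = F.normalize(torch.randn(1, 128), dim=1)

# Euclidean distance between the two embeddings
distance = torch.cdist(emb_a, emb_b).item()

# Pairs closer than some threshold are treated as the same class.
# The 0.5 here is illustrative; evaluate() in train.py searches for the best threshold.
same_class = distance <= 0.5
```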
To run the training script:
```bash
python train.py -h # Lists all optional arguments
python train.py # Runs training script with default args
```
Running the training script as is should yield 97% accuracy on the FMNIST test set within 10 epochs.
### Loss ###
`TripletMarginLoss` is a loss function which takes in a triplet of samples. A valid triplet consists of:
1. Anchor: a sample from the dataset
2. Positive: another sample with the same label/group as the anchor (Generally, positive != anchor)
3. Negative: a sample with a different label/group from the anchor
`TripletMarginLoss` (refer to `loss.py`) does the following:
```python
loss = max(dist(anchor, positive) - dist(anchor, negative) + margin, 0)
```
where `dist` is a distance function. Minimizing this function drives `dist(anchor, positive)` down and `dist(anchor, negative)` up, until the negative is at least `margin` farther from the anchor than the positive.
The FaceNet paper describes this loss in more detail.
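For a concrete feel, here is the loss worked through by hand for a single made-up triplet, using Euclidean distance and the training script's default `margin = 0.2`:
```python
import torch

anchor = torch.tensor([0.0, 0.0])
positive = torch.tensor([0.3, 0.4])  # dist(anchor, positive) = 0.5
negative = torch.tensor([0.0, 0.6])  # dist(anchor, negative) = 0.6
margin = 0.2

d_ap = torch.dist(anchor, positive)  # 0.5
d_an = torch.dist(anchor, negative)  # 0.6
loss = torch.clamp(d_ap - d_an + margin, min=0)  # max(0.5 - 0.6 + 0.2, 0) = 0.1
```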
### Sampler ###
In order to generate valid triplets from a batch of samples, we need to make sure that each batch has multiple samples with the same label. We do this using `PKSampler` (refer to `sampler.py`), which ensures that each batch of size `p * k` will have samples from exactly `p` classes and `k` samples per class.
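As a minimal sketch, the sampler can be driven directly with a made-up list of labels (in practice it is passed to a `DataLoader`, as in `train.py`):
```python
from sampler import PKSampler

# Made-up labels: 4 classes with 4 samples each (sample indices 0..15)
targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
p, k = 2, 2  # each batch: 2 classes, 2 samples per class

sampler = PKSampler(targets, p, k)
it = iter(sampler)
batch = [next(it) for _ in range(p * k)]  # one batch worth of sample indices
print([targets[i] for i in batch])        # e.g. [3, 3, 0, 0]
```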
### Triplet Mining ###
`TripletMarginLoss` currently supports the following mining techniques:
* `batch_all`: Generates all possible valid triplets from a batch and excludes the 'easy' triplets (those with `loss = 0`) before averaging the loss.
* `batch_hard`: For every anchor, `batch_hard` creates a triplet with the 'hardest' positive (farthest positive) and negative (closest negative).
These mining strategies usually speed up training.
This [webpage](https://omoindrot.github.io/triplet-loss) describes the sampling and mining strategies in more detail.
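As a rough sketch of what `batch_hard` computes, consider a toy batch of three samples with made-up labels and pairwise distances:
```python
import torch

labels = torch.tensor([0, 0, 1])
# Made-up pairwise distance matrix for the three samples
dist = torch.tensor([[0.0, 0.8, 0.9],
                     [0.8, 0.0, 0.5],
                     [0.9, 0.5, 0.0]])

same = labels.unsqueeze(0) == labels.unsqueeze(1)
eye = torch.eye(3, dtype=torch.bool)

# Hardest positive per anchor: farthest same-label sample (self excluded)
hardest_pos = (dist * (same & ~eye).float()).max(1).values  # [0.8, 0.8, 0.0]

# Hardest negative per anchor: closest different-label sample
# (same-label entries are pushed above the max so they never win the min)
hardest_neg = (dist + same.float() * dist.max()).min(1).values  # [0.9, 0.5, 0.5]

margin = 0.2
loss = torch.clamp(hardest_pos - hardest_neg + margin, min=0).mean()  # (0.1 + 0.5 + 0.0) / 3 = 0.2
```
Note that the lone class-1 sample has no valid positive, so its hardest-positive distance is zero, mirroring what the positive mask in `loss.py` produces.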
# loss.py
'''
PyTorch adaptation of https://omoindrot.github.io/triplet-loss
https://github.com/omoindrot/tensorflow-triplet-loss
'''
import torch
import torch.nn as nn


class TripletMarginLoss(nn.Module):
    def __init__(self, margin=1.0, p=2., mining='batch_all'):
        super(TripletMarginLoss, self).__init__()
        self.margin = margin
        self.p = p
        self.mining = mining

        if mining == 'batch_all':
            self.loss_fn = batch_all_triplet_loss
        elif mining == 'batch_hard':
            self.loss_fn = batch_hard_triplet_loss
        else:
            raise ValueError('Unsupported mining strategy: {}'.format(mining))

    def forward(self, embeddings, labels):
        return self.loss_fn(labels, embeddings, self.margin, self.p)


def batch_hard_triplet_loss(labels, embeddings, margin, p):
    pairwise_dist = torch.cdist(embeddings, embeddings, p=p)

    mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
    anchor_positive_dist = mask_anchor_positive * pairwise_dist

    # hardest positive for every anchor
    hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)

    mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()

    # Add max value in each row to invalid negatives so they never win the min below
    max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)

    # hardest negative for every anchor
    hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)

    triplet_loss = hardest_positive_dist - hardest_negative_dist + margin
    triplet_loss[triplet_loss < 0] = 0
    triplet_loss = triplet_loss.mean()

    # batch_hard does not track the fraction of positive triplets
    return triplet_loss, -1


def batch_all_triplet_loss(labels, embeddings, margin, p):
    pairwise_dist = torch.cdist(embeddings, embeddings, p=p)

    anchor_positive_dist = pairwise_dist.unsqueeze(2)
    anchor_negative_dist = pairwise_dist.unsqueeze(1)
    triplet_loss = anchor_positive_dist - anchor_negative_dist + margin

    mask = _get_triplet_mask(labels)
    triplet_loss = mask.float() * triplet_loss

    # Remove negative losses (easy triplets)
    triplet_loss[triplet_loss < 0] = 0

    # Count number of positive triplets (where triplet_loss > 0)
    valid_triplets = triplet_loss[triplet_loss > 1e-16]
    num_positive_triplets = valid_triplets.size(0)
    num_valid_triplets = mask.sum()
    fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)

    # Get final mean triplet loss over the positive valid triplets
    triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)

    return triplet_loss, fraction_positive_triplets


def _get_triplet_mask(labels):
    # Check that i, j and k are distinct
    indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    indices_not_equal = ~indices_equal
    i_not_equal_j = indices_not_equal.unsqueeze(2)
    i_not_equal_k = indices_not_equal.unsqueeze(1)
    j_not_equal_k = indices_not_equal.unsqueeze(0)
    distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k

    # Check that labels[i] == labels[j] and labels[i] != labels[k]
    label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    i_equal_j = label_equal.unsqueeze(2)
    i_equal_k = label_equal.unsqueeze(1)
    valid_labels = ~i_equal_k & i_equal_j

    return valid_labels & distinct_indices


def _get_anchor_positive_triplet_mask(labels):
    # Check that i and j are distinct
    indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    indices_not_equal = ~indices_equal

    # Check if labels[i] == labels[j]
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)

    return labels_equal & indices_not_equal


def _get_anchor_negative_triplet_mask(labels):
    return labels.unsqueeze(0) != labels.unsqueeze(1)
# model.py
import torch.nn as nn
import torchvision.models as models


class EmbeddingNet(nn.Module):
    def __init__(self, backbone=None):
        super(EmbeddingNet, self).__init__()
        if backbone is None:
            # 128-dimensional embedding head on a ResNet50 trunk
            backbone = models.resnet50(num_classes=128)
        self.backbone = backbone

    def forward(self, x):
        x = self.backbone(x)
        # L2-normalize embeddings so Euclidean distances are bounded
        x = nn.functional.normalize(x, dim=1)
        return x
# sampler.py
import random
from collections import defaultdict

import torch
from torch.utils.data.sampler import Sampler


def create_groups(groups, k):
    """Bins sample indices with respect to groups, removing bins
    with fewer than k samples

    Args:
        groups (list[int]): where ith index stores ith sample's group id
        k (int): minimum number of samples per bin

    Returns:
        defaultdict[list]: Bins of sample indices, binned by group_idx
    """
    group_samples = defaultdict(list)
    for sample_idx, group_idx in enumerate(groups):
        group_samples[group_idx].append(sample_idx)

    keys_to_remove = []
    for key in group_samples:
        if len(group_samples[key]) < k:
            keys_to_remove.append(key)

    for key in keys_to_remove:
        group_samples.pop(key)

    return group_samples


class PKSampler(Sampler):
    """
    Randomly samples from a dataset while ensuring that each batch (of size p * k)
    includes samples from exactly p labels, with k samples for each label.

    Args:
        groups (list[int]): List where the ith entry is the group_id/label of the
            ith sample in the dataset.
        p (int): Number of labels/groups to be sampled from in a batch
        k (int): Number of samples for each label/group in a batch
    """

    def __init__(self, groups, p, k):
        self.p = p
        self.k = k
        self.groups = create_groups(groups, self.k)

        # Ensures there are enough classes to sample from
        assert len(self.groups) >= p

    def __iter__(self):
        # Shuffle samples within groups
        for key in self.groups:
            random.shuffle(self.groups[key])

        # Keep track of the number of samples left for each group
        group_samples_remaining = {}
        for key in self.groups:
            group_samples_remaining[key] = len(self.groups[key])

        while len(group_samples_remaining) > self.p:
            # Select p groups at random from valid/remaining groups
            group_ids = list(group_samples_remaining.keys())
            selected_group_idxs = torch.multinomial(torch.ones(len(group_ids)), self.p).tolist()
            for i in selected_group_idxs:
                group_id = group_ids[i]
                group = self.groups[group_id]
                for _ in range(self.k):
                    # No need to pick samples at random since group samples are shuffled
                    sample_idx = len(group) - group_samples_remaining[group_id]
                    yield group[sample_idx]
                    group_samples_remaining[group_id] -= 1

                # Don't sample from a group if it has fewer than k samples remaining
                if group_samples_remaining[group_id] < self.k:
                    group_samples_remaining.pop(group_id)
# sampler tests
import unittest
from collections import defaultdict

from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
import torchvision.transforms as transforms

from sampler import PKSampler


class Tester(unittest.TestCase):

    def test_pksampler(self):
        p, k = 16, 4

        # Ensure sampler does not allow p to be greater than num_classes
        dataset = FakeData(size=100, num_classes=10, image_size=(3, 1, 1))
        targets = [target.item() for _, target in dataset]
        self.assertRaises(AssertionError, PKSampler, targets, p, k)

        # Ensure p, k constraints on batch
        dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1),
                           transform=transforms.ToTensor())
        targets = [target.item() for _, target in dataset]
        sampler = PKSampler(targets, p, k)
        loader = DataLoader(dataset, batch_size=p * k, sampler=sampler)

        for _, labels in loader:
            bins = defaultdict(int)
            for label in labels.tolist():
                bins[label] += 1

            # Ensure that each batch has samples from exactly p classes
            self.assertEqual(len(bins), p)

            # Ensure that there are k samples from each class
            for label in bins:
                self.assertEqual(bins[label], k)


if __name__ == '__main__':
    unittest.main()
# train.py
import os

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST

from loss import TripletMarginLoss
from sampler import PKSampler
from model import EmbeddingNet


def train_epoch(model, optimizer, criterion, data_loader, device, epoch, print_freq):
    model.train()
    running_loss = 0
    running_frac_pos_triplets = 0
    for i, data in enumerate(data_loader):
        optimizer.zero_grad()
        samples, targets = data[0].to(device), data[1].to(device)

        embeddings = model(samples)
        loss, frac_pos_triplets = criterion(embeddings, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_frac_pos_triplets += float(frac_pos_triplets)

        if i % print_freq == print_freq - 1:
            avg_loss = running_loss / print_freq
            avg_trip = 100.0 * running_frac_pos_triplets / print_freq
            print('[{:d}, {:d}] | loss: {:.4f} | % avg hard triplets: {:.2f}%'.format(
                epoch, i + 1, avg_loss, avg_trip))
            running_loss = 0
            running_frac_pos_triplets = 0


def find_best_threshold(dists, targets, device):
    best_thresh = 0.01
    best_correct = 0
    for thresh in torch.arange(0.0, 1.51, 0.01):
        predictions = dists <= thresh.to(device)
        correct = torch.sum(predictions == targets.to(device)).item()
        if correct > best_correct:
            best_thresh = thresh
            best_correct = correct

    accuracy = 100.0 * best_correct / dists.size(0)

    return best_thresh, accuracy


@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    embeds, labels = [], []

    for data in loader:
        samples, _labels = data[0].to(device), data[1]
        out = model(samples)
        embeds.append(out)
        labels.append(_labels)

    embeds = torch.cat(embeds, dim=0)
    labels = torch.cat(labels, dim=0)

    dists = torch.cdist(embeds, embeds)

    labels = labels.unsqueeze(0)
    targets = labels == labels.t()

    # Keep only distinct pairs: the strict upper triangle of the distance matrix
    mask = torch.ones(dists.size(), device=dists.device).triu(diagonal=1) == 1
    dists = dists[mask]
    targets = targets.to(device)[mask]

    threshold, accuracy = find_best_threshold(dists, targets, device)

    print('accuracy: {:.3f}%, threshold: {:.2f}'.format(accuracy, threshold))


def save(model, epoch, save_dir, file_name):
    file_name = 'epoch_' + str(epoch) + '__' + file_name
    save_path = os.path.join(save_dir, file_name)
    torch.save(model.state_dict(), save_path)


def main(args):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    p = args.labels_per_batch
    k = args.samples_per_label
    batch_size = p * k

    model = EmbeddingNet()
    if args.resume:
        model.load_state_dict(torch.load(args.resume))

    model.to(device)

    criterion = TripletMarginLoss(margin=args.margin)
    optimizer = Adam(model.parameters(), lr=args.lr)

    transform = transforms.Compose([transforms.Lambda(lambda image: image.convert('RGB')),
                                    transforms.Resize((224, 224)),
                                    transforms.ToTensor()])

    # Using FMNIST to demonstrate embedding learning using triplet loss. This dataset can
    # be replaced with any classification dataset.
    train_dataset = FashionMNIST(args.dataset_dir, train=True, transform=transform, download=True)
    test_dataset = FashionMNIST(args.dataset_dir, train=False, transform=transform, download=True)

    # targets is a list where the ith element corresponds to the label of the ith dataset element.
    # This is required for PKSampler to randomly sample from exactly p classes. You will need to
    # construct targets while building your dataset. Some datasets (such as ImageFolder) have a
    # targets attribute with the same format.
    targets = train_dataset.targets.tolist()

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              sampler=PKSampler(targets, p, k),
                              num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size,
                             shuffle=False,
                             num_workers=args.workers)

    for epoch in range(1, args.epochs + 1):
        print('Training...')
        train_epoch(model, optimizer, criterion, train_loader, device, epoch, args.print_freq)

        print('Evaluating...')
        evaluate(model, test_loader, device)

        print('Saving...')
        save(model, epoch, args.save_dir, 'ckpt.pth')


def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Embedding Learning')

    parser.add_argument('--dataset-dir', default='/tmp/fmnist/',
                        help='FashionMNIST dataset directory path')
    parser.add_argument('-p', '--labels-per-batch', default=8, type=int,
                        help='Number of unique labels/classes per batch')
    parser.add_argument('-k', '--samples-per-label', default=8, type=int,
                        help='Number of samples per label in a batch')
    parser.add_argument('--eval-batch-size', default=512, type=int,
                        help='Batch size for evaluation')
    parser.add_argument('--epochs', default=10, type=int, metavar='N',
                        help='Number of training epochs to run')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='Number of data loading workers')
    parser.add_argument('--lr', default=0.0001, type=float, help='Learning rate')
    parser.add_argument('--margin', default=0.2, type=float, help='Triplet loss margin')
    parser.add_argument('--print-freq', default=20, type=int, help='Print frequency')
    parser.add_argument('--save-dir', default='.', help='Model save directory')
    parser.add_argument('--resume', default='', help='Resume from checkpoint')

    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    main(args)