# train.py: embedding learning with a triplet loss, demonstrated on FashionMNIST.
import os

import torch
import torchvision.transforms as transforms
from loss import TripletMarginLoss
from model import EmbeddingNet
from sampler import PKSampler
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST


def train_epoch(model, optimizer, criterion, data_loader, device, epoch, print_freq):
    """Run one training epoch and periodically print running statistics.

    Args:
        model: Network mapping a batch of samples to embeddings; it is put in
            train mode before the loop.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        criterion: Callable ``criterion(embeddings, targets)`` returning
            ``(loss, frac_pos_triplets)``; ``loss`` must support ``backward()``
            and ``item()``, and ``frac_pos_triplets`` must be convertible with
            ``float``.
        data_loader: Iterable of batches whose first two items are the samples
            and their targets.
        device: Device the batch tensors are moved to.
        epoch: Epoch number; only used in the printed log line.
        print_freq: Print statistics averaged over the last ``print_freq``
            batches every ``print_freq`` batches; values <= 0 disable printing.
    """
    model.train()
    running_loss = 0.0
    running_frac_pos_triplets = 0.0
    for i, data in enumerate(data_loader):
        optimizer.zero_grad()
        samples, targets = data[0].to(device), data[1].to(device)

        embeddings = model(samples)

        loss, frac_pos_triplets = criterion(embeddings, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_frac_pos_triplets += float(frac_pos_triplets)

        # 1-based count of batches processed so far in this epoch.
        num_batches = i + 1
        if print_freq > 0 and num_batches % print_freq == 0:
            avg_loss = running_loss / print_freq
            avg_trip = 100.0 * running_frac_pos_triplets / print_freq
            print(f"[{epoch:d}, {num_batches:d}] | loss: {avg_loss:.4f} | % avg hard triplets: {avg_trip:.2f}%")
            # Reset so the next report averages only the following window.
            running_loss = 0.0
            running_frac_pos_triplets = 0.0


def find_best_threshold(dists, targets, device):
    """Pick the distance threshold that best separates same-label pairs.

    Thresholds from 0.0 to 1.50 in steps of 0.01 are tried. A pair is
    predicted "same label" when its distance is <= the threshold. The smallest
    threshold reaching the highest number of correct predictions is kept.

    Args:
        dists: Tensor of pair distances; expected to be 1-D, since accuracy
            is computed over ``dists.size(0)`` pairs.
        targets: Boolean tensor aligned with ``dists``; True where the pair
            shares a label.
        device: Device on which the comparisons are performed.

    Returns:
        ``(threshold, accuracy)``. ``accuracy`` is the percentage of pairs
        classified correctly at ``threshold``. If no threshold classifies any
        pair correctly, or ``dists`` is empty, ``threshold`` stays 0.01 and
        ``accuracy`` is 0.0.
    """
    best_thresh = 0.01
    best_correct = 0
    num_pairs = dists.size(0)
    if num_pairs == 0:
        # Nothing to classify; avoid dividing by zero below.
        return best_thresh, 0.0

    # Move the (potentially O(N^2)-sized) targets once, not on every iteration.
    targets = targets.to(device)
    for thresh in torch.arange(0.0, 1.51, 0.01):
        predictions = dists <= thresh.to(device)
        correct = torch.sum(predictions == targets).item()
        # Strict ">" keeps the smallest threshold among ties.
        if correct > best_correct:
            best_thresh = thresh
            best_correct = correct

    accuracy = 100.0 * best_correct / num_pairs

    return best_thresh, accuracy


@torch.inference_mode()
def evaluate(model, loader, device):
    """Report pairwise same-label verification accuracy on ``loader``.

    Every sample is embedded with ``model``. Each unordered pair of distinct
    samples is labelled positive when both share a label. The distance
    threshold that best separates positive from negative pairs is chosen by
    ``find_best_threshold`` and printed with its accuracy.

    Args:
        model: Embedding network; switched to eval mode.
        loader: Iterable of batches whose first two items are the samples
            and their labels.
        device: Device the samples (and model) live on.

    Returns:
        ``(threshold, accuracy)`` as returned by ``find_best_threshold``.
    """
    model.eval()
    embeds, labels = [], []

    for data in loader:
        samples, batch_labels = data[0].to(device), data[1]
        embeds.append(model(samples))
        labels.append(batch_labels)

    embeds = torch.cat(embeds, dim=0)
    # Keep labels next to the embeddings so every tensor below shares a device.
    labels = torch.cat(labels, dim=0).to(device)

    # (N, N) matrix of Euclidean distances between all embeddings.
    dists = torch.cdist(embeds, embeds)
    # targets[i, j] is True when samples i and j have the same label.
    targets = labels.unsqueeze(0) == labels.unsqueeze(1)

    # Strict upper triangle: every unordered pair once, self-pairs excluded.
    pair_mask = torch.ones(dists.shape, dtype=torch.bool, device=dists.device).triu(diagonal=1)
    dists = dists[pair_mask]
    targets = targets[pair_mask]

    threshold, accuracy = find_best_threshold(dists, targets, device)

    print(f"accuracy: {accuracy:.3f}%, threshold: {threshold:.2f}")
    return threshold, accuracy


def save(model, epoch, save_dir, file_name):
    """Save ``model``'s state dict as ``<save_dir>/epoch_<epoch>__<file_name>``.

    ``save_dir`` is created (including parents) if it does not exist yet.

    Args:
        model: Module whose ``state_dict()`` is written.
        epoch: Epoch number embedded in the file name.
        save_dir: Output directory; an empty string means the current directory.
        file_name: Base file name, prefixed with the epoch tag.

    Returns:
        The path the checkpoint was written to.
    """
    file_name = f"epoch_{epoch}__{file_name}"
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, file_name)
    torch.save(model.state_dict(), save_path)
    return save_path


def main(args):
    """Train an ``EmbeddingNet`` with a triplet margin loss on FashionMNIST.

    Each epoch trains on the training split, evaluates pairwise verification
    accuracy on the test split, and writes a checkpoint to ``args.save_dir``.

    Args:
        args: Namespace produced by ``parse_args``.
    """
    import operator

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    p = args.labels_per_batch
    k = args.samples_per_label
    batch_size = p * k

    model = EmbeddingNet()
    if args.resume:
        # Load onto the CPU first so a checkpoint saved on a GPU can be resumed
        # on a CPU-only host; the model is moved to ``device`` right after.
        # NOTE(review): torch.load unpickles the file -- only resume from
        # trusted checkpoints.
        model.load_state_dict(torch.load(args.resume, map_location="cpu"))

    model.to(device)

    criterion = TripletMarginLoss(margin=args.margin)
    optimizer = Adam(model.parameters(), lr=args.lr)

    transform = transforms.Compose(
        [
            # operator.methodcaller instead of a lambda: the transform is pickled
            # together with the dataset for DataLoader worker processes, and a
            # local lambda cannot be pickled when workers use the "spawn" start
            # method (the default on Windows and macOS).
            transforms.Lambda(operator.methodcaller("convert", "RGB")),
            transforms.Resize((224, 224)),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
    )

    # Using FMNIST to demonstrate embedding learning using triplet loss. This dataset can
    # be replaced with any classification dataset.
    train_dataset = FashionMNIST(args.dataset_dir, train=True, transform=transform, download=True)
    test_dataset = FashionMNIST(args.dataset_dir, train=False, transform=transform, download=True)

    # targets is a list where the i_th element corresponds to the label of i_th dataset element.
    # This is required for PKSampler to randomly sample from exactly p classes. You will need to
    # construct targets while building your dataset. Some datasets (such as ImageFolder) have a
    # targets attribute with the same format.
    targets = train_dataset.targets.tolist()

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=PKSampler(targets, p, k), num_workers=args.workers
    )
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.workers)

    for epoch in range(1, args.epochs + 1):
        print("Training...")
        train_epoch(model, optimizer, criterion, train_loader, device, epoch, args.print_freq)

        print("Evaluating...")
        evaluate(model, test_loader, device)

        print("Saving...")
        save(model, epoch, args.save_dir, "ckpt.pth")


def parse_args(argv=None):
    """Parse the command-line options of this training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        ``argparse.Namespace`` holding the parsed options.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Embedding Learning")

    parser.add_argument("--dataset-dir", default="/tmp/fmnist/", type=str, help="FashionMNIST dataset directory path")
    parser.add_argument(
        "-p", "--labels-per-batch", default=8, type=int, help="Number of unique labels/classes per batch"
    )
    parser.add_argument("-k", "--samples-per-label", default=8, type=int, help="Number of samples per label in a batch")
    parser.add_argument("--eval-batch-size", default=512, type=int, help="batch size for evaluation")
    parser.add_argument("--epochs", default=10, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers")
    parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
    parser.add_argument("--margin", default=0.2, type=float, help="Triplet loss margin")
    parser.add_argument("--print-freq", default=20, type=int, help="print frequency")
    parser.add_argument("--save-dir", default=".", type=str, help="Model save directory")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")

    return parser.parse_args(argv)


# Script entry point: parse the command line and run training.
if __name__ == "__main__":
    args = parse_args()
    main(args)