knn_monitor.py 4.48 KB
Newer Older
1
2
3
4
5
6
import torch.nn.functional as F
import torch
from megatron import print_rank_0, get_args, mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder

Vijay Korthikanti's avatar
Vijay Korthikanti committed
7
8
9
# Module-level cache for the kNN monitor. compute_feature_bank() stores a
# (feature_banks, feature_labels, classes) tuple here; get_feature_bank()
# reads it back. None means the bank has not been computed yet.
_FEATURE_BANK = None


10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def build_data_loader(dataset, drop_last=True, shuffle=False, micro_batch_size=16):
    """Build a data-parallel DataLoader over `dataset` for the kNN monitor.

    Note that the batch size is the local (per GPU) batch size.

    Args:
        dataset: map-style dataset to shard across the data-parallel ranks.
        drop_last: forwarded to the DistributedSampler. If True, the sampler
            drops the dataset tail so it divides evenly across ranks. If False,
            it pads with repeated indices. The DataLoader receives the
            *opposite* value (see comment below).
        shuffle: forwarded to the DistributedSampler.
        micro_batch_size: per-GPU batch size. It defaults to 16, the value
            previously hard-coded here.

    Returns:
        torch.utils.data.DataLoader yielding this rank's shard of `dataset`.
    """
    args = get_args()
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()

    # Sampler: gives every data-parallel rank the same number of samples.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        drop_last=drop_last, shuffle=shuffle
    )

    # Data loader. Note that batch size is the per GPU batch size.
    # shuffle must be False here because ordering is delegated to the sampler.
    # NOTE(review): drop_last is intentionally inverted. With the default
    # (sampler drops the uneven tail) the loader keeps its final partial batch,
    # so every sample assigned to this rank is used. Confirm this is the
    # intended pairing before changing it.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        drop_last=not drop_last,
        pin_memory=True,
    )
    return data_loader


def _extract_local_features(model, dataloader):
    """Embed this rank's shard with model[0]; return ([N', D] features, [N'] labels).

    Features are the L2-normalized teacher outputs, cast to float32. Each
    module's train/eval mode is restored afterwards, even if the forward
    pass raises.
    """
    was_training = [m.training for m in model]
    for m in model:
        m.eval()

    features, labels = [], []
    try:
        with torch.no_grad():
            for batch in dataloader:
                images = batch[0].cuda().contiguous()
                targets = batch[1].cuda().contiguous()
                # model[0] returns (student_feature, teacher_feature); only the
                # teacher output goes into the bank.
                _, teacher_feature = model[0](images)
                features.append(F.normalize(teacher_feature.float(), dim=1))
                labels.append(targets)
    finally:
        for m, mode in zip(model, was_training):
            m.train(mode)

    if not features:
        raise RuntimeError(
            "kNN monitor: no training samples were assigned to this "
            "data-parallel rank; the feature bank would be empty")

    # [N', D] and [N']
    return (torch.cat(features, dim=0).contiguous(),
            torch.cat(labels, dim=0).contiguous())


def _gather_across_data_parallel(tensor):
    """All-gather `tensor` from every data-parallel rank and concatenate on dim 0.

    All ranks must contribute tensors of identical shape. The DistributedSampler
    used by build_data_loader gives each rank the same number of samples.
    """
    gathered = [torch.zeros_like(tensor)
                for _ in range(mpu.get_data_parallel_world_size())]
    torch.distributed.all_gather(gathered,
                                 tensor,
                                 group=mpu.get_data_parallel_group())
    # Sanity check: this rank's slot must hold exactly what it contributed.
    assert torch.equal(gathered[mpu.get_data_parallel_rank()], tensor)
    return torch.cat(gathered, dim=0)


def compute_feature_bank(model):
    """Build the kNN feature bank from the training set and cache it module-wide.

    Every data-parallel rank embeds its shard of the training ImageFolder
    (args.data_path[0]) with model[0]. The shards are then all-gathered, and
    the result is stored for get_feature_bank() as
    (feature_banks [D, N], feature_labels [N], classes).

    Args:
        model: list of model modules. model[0] is called on images and must
            return a (student_feature, teacher_feature) pair.
    """
    global _FEATURE_BANK
    args = get_args()

    train_ds = ImageFolder(
        root=args.data_path[0],
        transform=ClassificationTransform((args.img_h, args.img_w), train=False),
        data_per_class_fraction=1.0
    )
    classes = len(train_ds.classes)
    dataloader = build_data_loader(train_ds)

    local_features, local_labels = _extract_local_features(model, dataloader)

    # [D, N]: transposed so knn_predict can compute similarities with one mm.
    feature_banks = _gather_across_data_parallel(local_features).t().contiguous()
    # [N]
    feature_labels = _gather_across_data_parallel(local_labels).contiguous()
    print_rank_0("feature_banks size is {}".format(feature_banks.size()))
    print_rank_0("feature labels size is {}".format(feature_labels.size()))

    _FEATURE_BANK = (feature_banks, feature_labels, classes)


def get_feature_bank():
    """Return the cached (feature_banks, feature_labels, classes) tuple.

    Raises:
        RuntimeError: if compute_feature_bank() has not been called yet. An
            explicit raise is used instead of `assert`, which `python -O`
            strips.
    """
    if _FEATURE_BANK is None:
        raise RuntimeError(
            "kNN feature bank is not initialized; call compute_feature_bank() first")
    return _FEATURE_BANK
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128


# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and
# https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """Rank classes for each query by a temperature-weighted k-nearest-neighbour vote.

    Args:
        feature: [B, D] query features. The dot product is treated as a cosine
            similarity, so queries are assumed L2-normalized (TODO confirm at
            call sites).
        feature_bank: [D, N] reference features, one column per bank sample.
        feature_labels: [N] integer class label of each bank column.
        classes: number of classes C.
        knn_k: number of neighbours that vote. It is clamped to N so a small
            bank does not make `topk` raise.
        knn_t: temperature; each neighbour votes with weight exp(sim / knn_t).

    Returns:
        [B, C] tensor of class indices sorted by descending vote score.
        Column 0 is the top-1 prediction.
    """
    num_queries = feature.size(0)
    # topk raises when k exceeds the number of bank entries.
    k = min(knn_k, feature_bank.size(1))

    # compute cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
    # [B, K]: label of each selected neighbour.
    sim_labels = torch.gather(feature_labels.expand(num_queries, -1),
                              dim=-1,
                              index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # One-hot encode neighbour labels ---> [B*K, C], in the same dtype as the weights.
    one_hot_label = torch.zeros(num_queries * k,
                                classes,
                                device=sim_labels.device,
                                dtype=sim_weight.dtype)
    one_hot_label = one_hot_label.scatter(dim=-1,
                                          index=sim_labels.view(-1, 1),
                                          value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(
        one_hot_label.view(num_queries, k, classes) * sim_weight.unsqueeze(dim=-1),
        dim=1)

    return pred_scores.argsort(dim=-1, descending=True)