model.py 395 Bytes
Newer Older
1
2
3
4
5
6
import torch.nn as nn
import torchvision.models as models


class EmbeddingNet(nn.Module):
    """Map inputs to L2-normalised embedding vectors.

    The wrapped ``backbone`` produces raw features, and ``forward``
    rescales them to unit L2 norm along dimension 1.

    Args:
        backbone: Module that produces the raw embedding. If ``None``, a
            randomly initialised torchvision ResNet-50 is built. Its final
            fully-connected layer outputs ``embedding_dim`` features.
        embedding_dim: Size of the default ResNet-50 output layer. It is
            only used when ``backbone`` is ``None``; a caller-supplied
            backbone determines its own output size.
    """

    def __init__(self, backbone=None, embedding_dim: int = 128):
        super().__init__()

        if backbone is None:
            # Reuse ResNet-50's classification head as the embedding layer
            # by sizing its output to the embedding dimension.
            backbone = models.resnet50(num_classes=embedding_dim)

        self.backbone = backbone

    def forward(self, x):
        """Embed ``x`` and return unit-norm vectors.

        Args:
            x: Input batch accepted by ``self.backbone``.

        Returns:
            The backbone output, L2-normalised along dim 1. For the
            default ResNet-50 that is a ``(batch, embedding_dim)`` tensor
            whose rows have unit norm. ``normalize`` clamps the norm at a
            small epsilon, so all-zero rows stay zero instead of producing
            NaN.
        """
        x = self.backbone(x)
        x = nn.functional.normalize(x, dim=1)
        return x