from torch import Tensor
from torch import nn
from typing import Dict
import torch.nn.functional as F


class Normalize(nn.Module):
    """Layer that rescales the sentence embedding to unit L2 norm.

    The layer has no parameters and keeps no state of its own.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, features: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """L2-normalize ``features["sentence_embedding"]`` along dim 1.

        The entry is overwritten in the dictionary that was passed in, and
        that same dictionary object is returned; other entries are left
        untouched.
        """
        unit_length = F.normalize(features["sentence_embedding"], p=2, dim=1)
        features["sentence_embedding"] = unit_length
        return features

    def save(self, output_path) -> None:
        """Do nothing: this layer has no state to write to ``output_path``."""
        return None

    @staticmethod
    def load(input_path) -> "Normalize":
        """Return a new ``Normalize`` layer; ``input_path`` is not read."""
        return Normalize()