import json
import os
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

from torch import Tensor, nn

from sentence_transformers.util import import_from_string


class Asym(nn.Sequential):
    def __init__(self, sub_modules: Dict[str, List[nn.Module]], allow_empty_key: bool = True):
        """
        This model allows to create asymmetric SentenceTransformer models, that apply different models depending on the specified input key.

        In the below example, we create two different Dense models for 'query' and 'doc'. Text that is passed as {'query': 'My query'} will
        be passed along the first Dense model, and text that will be passed as {'doc': 'My document'} will use the other Dense model.

        Note, that when you call encode(), that only inputs of the same type can be encoded. Mixed-Types cannot be encoded.

        Example::
            word_embedding_model = models.Transformer(model_name)
            pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
            asym_model = models.Asym({'query': [models.Dense(word_embedding_model.get_word_embedding_dimension(), 128)], 'doc': [models.Dense(word_embedding_model.get_word_embedding_dimension(), 128)]})
            model = SentenceTransformer(modules=[word_embedding_model, pooling_model, asym_model])

            model.encode([{'query': 'Q1'}, {'query': 'Q2'}])
            model.encode([{'doc': 'Doc1'}, {'doc': 'Doc2'}])

            #You can train it with InputExample like this. Note, that the order must always be the same:
            train_example = InputExample(texts=[{'query': 'Train query'}, {'doc': 'Document'}], label=1)

        Args:
            sub_modules: Dict in the format str -> List[models]. The
                models in the specified list will be applied for input
                marked with the respective key. A single module (or a
                tuple of modules) is also accepted and treated as a list.
            allow_empty_key: If true, inputs without a key can be
                processed. If false, an exception will be thrown if no
                key is specified.
        """
        # Normalise every value to a list: forward(), save() and
        # get_sentence_embedding_dimension() iterate / index these chains, so a
        # bare module must not be stored as-is.
        normalized: Dict[str, List[nn.Module]] = {}
        ordered_dict = OrderedDict()
        for name, models in sub_modules.items():
            models = list(models) if isinstance(models, (list, tuple)) else [models]
            normalized[name] = models
            # Register each module on the nn.Sequential as "<key>-<idx>" so its
            # parameters are tracked by the parent module.
            for idx, model in enumerate(models):
                ordered_dict[f"{name}-{idx}"] = model
        super().__init__(ordered_dict)

        self.sub_modules = normalized
        self.allow_empty_key = allow_empty_key

    def forward(self, features: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run ``features`` through the module chain selected by its text key.

        The key is taken from ``features["text_keys"][0]``; only the first key is
        consulted because mixed batches are not supported.

        Raises:
            ValueError: if ``features`` carries no key and ``allow_empty_key`` is False.
        """
        if "text_keys" in features and len(features["text_keys"]) > 0:
            text_key = features["text_keys"][0]
            for model in self.sub_modules[text_key]:
                features = model(features)
        elif not self.allow_empty_key:
            raise ValueError("Input did not specify any keys and allow_empty_key is False")

        return features

    def get_sentence_embedding_dimension(self) -> Optional[int]:
        """Return the output embedding size, or None if no sub-module reports one.

        Within a chain the output size is determined by the last module that
        reports a dimension, so each chain is searched from the end.
        """
        for models in self.sub_modules.values():
            for model in reversed(models):
                if hasattr(model, "get_sentence_embedding_dimension"):
                    return model.get_sentence_embedding_dimension()
        return None

    def save(self, output_path):
        """Save every distinct sub-module to its own folder plus a config.json.

        A module instance shared between several keys is saved only once; the
        ``structure`` entry of the config references it by the same id.
        """
        model_lookup = {}
        model_types = {}
        model_structure = {}

        for name, models in self.sub_modules.items():
            model_structure[name] = []
            for model in models:
                # id() makes the folder name unique per instance, so shared
                # modules collapse onto one entry.
                model_id = str(id(model)) + "_" + type(model).__name__
                model_lookup[model_id] = model
                model_types[model_id] = type(model).__module__
                model_structure[name].append(model_id)

        for model_id, model in model_lookup.items():
            model_path = os.path.join(output_path, str(model_id))
            os.makedirs(model_path, exist_ok=True)
            model.save(model_path)

        with open(os.path.join(output_path, "config.json"), "w", encoding="utf8") as fOut:
            json.dump(
                {
                    "types": model_types,
                    "structure": model_structure,
                    "parameters": {"allow_empty_key": self.allow_empty_key},
                },
                fOut,
                indent=2,
            )

    def tokenize(self, texts: Union[List[str], List[Dict[str, str]], List[Tuple[str, str]]], **kwargs):
        """Tokenizes a text and maps tokens to token-ids

        ``texts`` must be a list of single-entry dicts ``{key: text}`` that all
        share the same key; tokenization is delegated to the first module of
        that key's chain, which receives the dicts unchanged.

        Raises:
            AttributeError: if the inputs are not dicts.
            ValueError: if the batch mixes different keys.
        """
        if not isinstance(texts[0], dict):
            raise AttributeError("Asym. model requires that texts are passed as dicts: {'key': 'text'}")

        module_key = None
        for lookup in texts:
            text_key, _ = next(iter(lookup.items()))
            if module_key is None:
                module_key = text_key
            elif text_key != module_key:
                # Explicit raise instead of assert so the check survives `python -O`.
                raise ValueError(
                    f"Mixed batches are not allowed: found key '{text_key}' after '{module_key}'"
                )
        return self.sub_modules[module_key][0].tokenize(texts, **kwargs)

    @staticmethod
    def load(input_path):
        """Rebuild an Asym model from a folder written by ``save``."""
        with open(os.path.join(input_path, "config.json"), encoding="utf8") as fIn:
            config = json.load(fIn)

        modules = {}
        for model_id, model_type in config["types"].items():
            module_class = import_from_string(model_type)
            module = module_class.load(os.path.join(input_path, model_id))
            modules[model_id] = module

        model_structure = {}
        for key_name, models_list in config["structure"].items():
            model_structure[key_name] = [modules[model_id] for model_id in models_list]

        model = Asym(model_structure, **config["parameters"])
        return model