ours2st.py 2.75 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
from typing import Optional, List
from dataclasses import dataclass, field
from sentence_transformers import models, SentenceTransformer
from transformers import HfArgumentParser


def _resolve_pooling(pooling_method):
    """Validate *pooling_method* and return ``(pooling_mode, use_dense)`` without mutating it."""
    if pooling_method is None:
        pooling_method = ["cls"]
    elif isinstance(pooling_method, str):
        # Treat a bare string as a one-element list instead of iterating its characters.
        pooling_method = [pooling_method]

    # Work on a copy: mutating the argument would also mutate the caller's list
    # (and, previously, this function's own default argument).
    remaining = list(pooling_method)

    if "decoder" in remaining:
        raise AssertionError(
            "Pooling method 'decoder' cannot be saved as sentence_transformers because it "
            "uses the decoder stack to produce sentence embedding."
        )

    if "cls" in remaining:
        pooling_mode = "cls"
    elif "mean" in remaining:
        pooling_mode = "mean"
    else:
        raise NotImplementedError(f"Fail to find cls or mean in pooling_method {remaining}!")
    remaining.remove(pooling_mode)

    use_dense = "dense" in remaining
    if use_dense:
        remaining.remove("dense")

    # Anything left over (e.g. both "cls" and "mean", duplicates, typos) is rejected.
    if remaining:
        raise AssertionError(f"Found unused pooling_method {remaining}!")
    return pooling_mode, use_dense


def convert_ours_ckpt_to_sentence_transformer(src_dir, dest_dir, pooling_method: Optional[List[str]] = None, dense_metric: str = "cos"):
    """Wrap a local encoder checkpoint as a sentence_transformers model and save it.

    The saved module pipeline is ``models.Transformer`` -> ``models.Pooling``
    (CLS or mean) -> optional ``models.Dense`` -> optional ``models.Normalize``.

    Args:
        src_dir: Path to the encoder checkpoint; must exist on disk.
        dest_dir: Where to save the sentence_transformers model. ``None``
            saves into ``src_dir``.
        pooling_method: Pooling components. It must contain exactly one of
            ``"cls"`` / ``"mean"``, optionally plus ``"dense"``. ``"decoder"``
            is rejected. Defaults to ``["cls"]``. A single string is treated
            as a one-element list. The caller's list is never modified.
        dense_metric: ``"cos"`` appends a ``models.Normalize`` layer. Any
            other value adds no normalization.

    Raises:
        AssertionError: ``src_dir`` is missing, ``"decoder"`` is requested,
            or ``pooling_method`` has unused/unrecognized entries.
        NotImplementedError: neither ``"cls"`` nor ``"mean"`` is requested.
    """
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    # AssertionError is kept as the type for backward compatibility.
    if src_dir is None or not os.path.exists(src_dir):
        raise AssertionError(f"Make sure the encoder path {src_dir} is valid on disk!")

    # Validate pooling options before the (potentially large) model is loaded.
    pooling_mode, use_dense = _resolve_pooling(pooling_method)

    if dest_dir is None:
        dest_dir = src_dir

    print(f"loading model from {src_dir} and saving the sentence_transformer model at {dest_dir}...")

    word_embedding_model = models.Transformer(src_dir)
    ndim = word_embedding_model.get_word_embedding_dimension()
    modules = [word_embedding_model, models.Pooling(ndim, pooling_mode=pooling_mode)]

    if use_dense:
        # NOTE(review): no trained weights are passed to models.Dense here, so the
        # projection appears freshly initialized. Per the sentence_transformers
        # signature, its default activation is Tanh. Confirm this matches the
        # training-time "dense" pooling before relying on converted "dense" models.
        modules.append(models.Dense(ndim, ndim, bias=False))

    if dense_metric == "cos":
        # Cosine similarity equals the inner product of L2-normalized embeddings.
        modules.append(models.Normalize())

    model = SentenceTransformer(modules=modules, device='cpu')
    model.save(dest_dir)


@dataclass
class Args:
    """Command-line options for the checkpoint -> sentence_transformers conversion.

    Constructing an instance immediately runs the conversion from
    ``__post_init__``. Parsing the CLI into this dataclass is therefore
    enough to trigger it.
    """

    # Encoder checkpoint location; forwarded as ``src_dir``.
    encoder: Optional[str] = field(
        default=None,
        metadata=dict(help='Path to the encoder model.'),
    )
    # Destination directory; forwarded as ``dest_dir``.
    output_dir: Optional[str] = field(
        default=None,
        metadata=dict(help='Path to the output sentence_transformer model.'),
    )
    # default_factory gives each instance its own fresh ["cls"] list.
    pooling_method: List[str] = field(
        default_factory=lambda: ["cls"],
        metadata=dict(help='Pooling methods to aggregate token embeddings for a sequence embedding. {cls, mean, dense, decoder}'),
    )
    # Similarity metric; forwarded as ``dense_metric``.
    dense_metric: str = field(
        default="cos",
        metadata=dict(help='What type of metric for dense retrieval? ip, l2, or cos.'),
    )
    # Accepted on the command line but not forwarded by __post_init__ below.
    model_cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help='Cache folder for huggingface transformers.'),
    )

    def __post_init__(self):
        convert_ours_ckpt_to_sentence_transformer(
            src_dir=self.encoder,
            dest_dir=self.output_dir,
            pooling_method=self.pooling_method,
            dense_metric=self.dense_metric,
        )

if __name__ == "__main__":
    # Parsing the command line constructs an Args instance, and
    # Args.__post_init__ performs the conversion; nothing else is needed here.
    cli_parser = HfArgumentParser([Args])
    (cli_args,) = cli_parser.parse_args_into_dataclasses()