# semantic_search.py
"""
This is a simple application for sentence embeddings: semantic search

We have a corpus with various sentences. Then, for a given query sentence,
we want to find the most similar sentence in this corpus.

For each query, this script prints the top 5 most similar sentences in the corpus.
"""

import torch

from sentence_transformers import SentenceTransformer

# Sentence-embedding model; the same instance encodes the corpus and every
# query so that all embeddings live in the same vector space.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Corpus with example sentences
corpus = [
    "A man is eating food.",
    "A man is eating a piece of bread.",
    "The girl is carrying a baby.",
    "A man is riding a horse.",
    "A woman is playing violin.",
    "Two men pushed carts through the woods.",
    "A man is riding a white horse on an enclosed ground.",
    "A monkey is playing drums.",
    "A cheetah is running behind its prey.",
]

# Encode the whole corpus once, up front, so each query below reuses the same
# embeddings instead of re-encoding the corpus.
# Use "convert_to_tensor=True" to keep the tensors on GPU (if available)
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)

# Query sentences:
queries = [
    "A man is eating pasta.",
    "Someone in a gorilla costume is playing a set of drums.",
    "A cheetah chases prey on across a field.",
]

# Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity.
# Clamp to the corpus size: torch.topk raises if k exceeds the number of candidates.
top_k = min(5, len(corpus))
for query in queries:
    query_embedding = embedder.encode(query, convert_to_tensor=True)

    # Score the query against every corpus sentence. Only one query is encoded
    # per iteration, so row 0 of the similarity matrix holds all of its scores.
    # torch.topk then yields the top_k scores (highest first) and their
    # positions in the corpus.
    similarity_scores = embedder.similarity(query_embedding, corpus_embeddings)[0]
    scores, indices = torch.topk(similarity_scores, k=top_k)

    print("\nQuery:", query)
    # Report the real result count rather than a hard-coded 5 so the heading
    # stays correct when the corpus holds fewer than five sentences.
    print(f"Top {top_k} most similar sentences in corpus:")

    # .tolist() turns the result tensors into plain Python floats/ints, which
    # are then used for number formatting and list indexing.
    for score, idx in zip(scores.tolist(), indices.tolist()):
        print(corpus[idx], f"(Score: {score:.4f})")

    # Alternatively, util.semantic_search performs cosine similarity + top-k in
    # a single call (requires `from sentence_transformers import util`):
    #
    #     hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
    #     hits = hits[0]  # Get the hits for the first query
    #     for hit in hits:
    #         print(corpus[hit["corpus_id"]], "(Score: {:.4f})".format(hit["score"]))