# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Example of using the rerank API of vLLM's OpenAI-compatible server, which is
also compatible with the Cohere SDK: https://github.com/cohere-ai/cohere-python

Running this example requires the Cohere SDK: `pip install cohere`.

Start the server first with: vllm serve BAAI/bge-reranker-base
"""
import cohere
from cohere import Client, ClientV2

# Reranker model name sent with each request; it should match the model the
# server was started with (see the `vllm serve` command in the module docstring).
model = "BAAI/bge-reranker-base"

# Query that every candidate document is scored against.
query = "What is the capital of France?"

# Candidate passages to rerank. Only the first one actually answers the query,
# which makes the reordering easy to see in the printed results.
documents = [
    "The capital of France is Paris",
    "Reranking is fun!",
    "vLLM is an open-source framework for fast AI serving",
]


25
def cohere_rerank(
26
    client: Client | ClientV2, model: str, query: str, documents: list[str]
27
) -> dict:
Reid's avatar
Reid committed
28
29
30
31
32
    return client.rerank(model=model, query=query, documents=documents)


def main():
    # cohere v1 client
33
    cohere_v1 = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
Reid's avatar
Reid committed
34
35
36
37
38
39
    rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents)
    print("-" * 50)
    print("rerank_v1_result:\n", rerank_v1_result)
    print("-" * 50)

    # or the v2
40
    cohere_v2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
Reid's avatar
Reid committed
41
42
43
44
    rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents)
    print("rerank_v2_result:\n", rerank_v2_result)
    print("-" * 50)

45

Reid's avatar
Reid committed
46
47
if __name__ == "__main__":
    main()