# SPDX-License-Identifier: Apache-2.0

from vllm import LLM

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create an LLM.
# You should pass task="embed" for embedding models.
model = LLM(
    model="intfloat/e5-mistral-7b-instruct",
    task="embed",
    # Always execute in eager-mode PyTorch (per vLLM's engine-argument
    # docs, this disables CUDA graph capture).
    enforce_eager=True,
)

# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.embed(prompts)

# Print the outputs, showing at most the first 16 values of each embedding
# followed by its full length.
# NOTE(review): the truncation below assumes ``output.outputs.embedding`` is a
# Python list, so ``str(...)[:-1]`` drops the closing "]" before ", ...]" is
# appended — confirm against the EmbeddingRequestOutput type.
for prompt, output in zip(prompts, outputs):
    embeds = output.outputs.embedding
    embeds_trimmed = ((str(embeds[:16])[:-1] +
                       ", ...]") if len(embeds) > 16 else embeds)
    print(f"Prompt: {prompt!r} | "
          f"Embeddings: {embeds_trimmed} (size={len(embeds)})")