# embedding.py: offline text-embedding example using vLLM.
from vllm import LLM

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


def format_embedding(embeds, max_items=16):
    """Return a short printable preview of an embedding vector.

    Args:
        embeds: A sequence of numbers (e.g. the ``embedding`` list of an
            embedding output).
        max_items: Maximum number of leading values to show; anything beyond
            that is elided as ``...``.

    Returns:
        A string such as ``"[0.1, 0.2, ...]"``. For a list this is identical
        to ``str(embeds)`` when it is not truncated.
    """
    # Build the brackets explicitly instead of slicing the trailing "]" off
    # str(...), which only works when the sequence's repr ends with "]".
    shown = ", ".join(map(repr, embeds[:max_items]))
    if len(embeds) > max_items:
        shown += ", ..."
    return f"[{shown}]"


def main():
    """Embed the sample prompts and print a truncated preview of each vector."""
    # Create an LLM.
    # You should pass task="embed" for embedding models
    model = LLM(
        model="intfloat/e5-mistral-7b-instruct",
        task="embed",
        enforce_eager=True,
    )

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    outputs = model.embed(prompts)

    # Print the outputs.
    for prompt, output in zip(prompts, outputs):
        embeds = output.outputs.embedding
        print(f"Prompt: {prompt!r} | "
              f"Embeddings: {format_embedding(embeds)} (size={len(embeds)})")


# Only load the model and run inference when executed as a script, not on
# import.
if __name__ == "__main__":
    main()