Unverified Commit dcf8862f authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Examples][1/n] Resettle basic examples. (#35579)


Signed-off-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: default avatarwang.yuqi <noooop@126.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 43aa3892
...@@ -5,6 +5,7 @@ from argparse import Namespace ...@@ -5,6 +5,7 @@ from argparse import Namespace
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.print_utils import print_embeddings
def parse_args(): def parse_args():
...@@ -41,10 +42,8 @@ def main(args: Namespace): ...@@ -41,10 +42,8 @@ def main(args: Namespace):
print("\nGenerated Outputs:\n" + "-" * 60) print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs): for prompt, output in zip(prompts, outputs):
rewards = output.outputs.data rewards = output.outputs.data
rewards_trimmed = ( print(f"Prompt: {prompt!r}")
(str(rewards[:16])[:-1] + ", ...]") if len(rewards) > 16 else rewards print_embeddings(rewards, prefix="Reward")
)
print(f"Prompt: {prompt!r} \nReward: {rewards_trimmed} (size={len(rewards)})")
print("-" * 60) print("-" * 60)
......
...@@ -17,7 +17,7 @@ def test_platform_plugins(): ...@@ -17,7 +17,7 @@ def test_platform_plugins():
example_file = os.path.join( example_file = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(current_file))), os.path.dirname(os.path.dirname(os.path.dirname(current_file))),
"examples", "examples",
"offline_inference/basic/basic.py", "basic/offline_inference/basic.py",
) )
runpy.run_path(example_file) runpy.run_path(example_file)
......
...@@ -2,6 +2,6 @@ ...@@ -2,6 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def print_embeddings(embeds: list[float]): def print_embeddings(embeds: list[float], prefix: str = "Embeddings"):
embeds_trimmed = (str(embeds[:4])[:-1] + ", ...]") if len(embeds) > 4 else embeds embeds_trimmed = (str(embeds[:4])[:-1] + ", ...]") if len(embeds) > 4 else embeds
print(f"Embeddings: {embeds_trimmed} (size={len(embeds)})") print(f"{prefix}: {embeds_trimmed} (size={len(embeds)})")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment