"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "1a956e136beae057746af6257ffa8da601730f10"
Commit 81fbb365 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[CI/Build] Test both text and token IDs in batched OpenAI Completions API (#5568)

parent 0e9164b4
...@@ -655,50 +655,52 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, ...@@ -655,50 +655,52 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
[MODEL_NAME, "zephyr-lora"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
# test simple list # test both text and token IDs
batch = await client.completions.create( for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
model=model_name, # test simple list
prompt=["Hello, my name is", "Hello, my name is"], batch = await client.completions.create(
max_tokens=5, model=model_name,
temperature=0.0, prompt=prompts,
) max_tokens=5,
assert len(batch.choices) == 2 temperature=0.0,
assert batch.choices[0].text == batch.choices[1].text )
assert len(batch.choices) == 2
# test n = 2 assert batch.choices[0].text == batch.choices[1].text
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but not necessary
# for official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"
# test streaming # test n = 2
batch = await client.completions.create( batch = await client.completions.create(
model=model_name, model=model_name,
prompt=["Hello, my name is", "Hello, my name is"], prompt=prompts,
max_tokens=5, n=2,
temperature=0.0, max_tokens=5,
stream=True, temperature=0.0,
) extra_body=dict(
texts = [""] * 2 # NOTE: this has to be true for n > 1 in vLLM, but not necessary
async for chunk in batch: # for official client.
assert len(chunk.choices) == 1 use_beam_search=True),
choice = chunk.choices[0] )
texts[choice.index] += choice.text assert len(batch.choices) == 4
assert texts[0] == texts[1] assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"
# test streaming
batch = await client.completions.create(
model=model_name,
prompt=prompts,
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
async for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
assert texts[0] == texts[1]
@pytest.mark.asyncio @pytest.mark.asyncio
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment