Unverified commit d92b3c5c, authored by Joe and committed by GitHub
Browse files

[Bugfix][CI/Build] Test prompt adapters in openai entrypoint tests (#6419)

parent 9ad32dac
...@@ -17,9 +17,13 @@ from ...utils import RemoteOpenAIServer ...@@ -17,9 +17,13 @@ from ...utils import RemoteOpenAIServer
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing # technically these adapters use a different base model,
# generation quality here # but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora" LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -28,7 +32,12 @@ def zephyr_lora_files(): ...@@ -28,7 +32,12 @@ def zephyr_lora_files():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(zephyr_lora_files): def zephyr_pa_files():
return snapshot_download(repo_id=PA_NAME)
@pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_pa_files):
with RemoteOpenAIServer([ with RemoteOpenAIServer([
"--model", "--model",
MODEL_NAME, MODEL_NAME,
...@@ -37,8 +46,10 @@ def server(zephyr_lora_files): ...@@ -37,8 +46,10 @@ def server(zephyr_lora_files):
"bfloat16", "bfloat16",
"--max-model-len", "--max-model-len",
"8192", "8192",
"--max-num-seqs",
"128",
"--enforce-eager", "--enforce-eager",
# lora config below # lora config
"--enable-lora", "--enable-lora",
"--lora-modules", "--lora-modules",
f"zephyr-lora={zephyr_lora_files}", f"zephyr-lora={zephyr_lora_files}",
...@@ -47,7 +58,14 @@ def server(zephyr_lora_files): ...@@ -47,7 +58,14 @@ def server(zephyr_lora_files):
"64", "64",
"--max-cpu-loras", "--max-cpu-loras",
"2", "2",
"--max-num-seqs", # pa config
"--enable-prompt-adapter",
"--prompt-adapters",
f"zephyr-pa={zephyr_pa_files}",
f"zephyr-pa2={zephyr_pa_files}",
"--max-prompt-adapters",
"2",
"--max-prompt-adapter-token",
"128", "128",
]) as remote_server: ]) as remote_server:
yield remote_server yield remote_server
...@@ -60,11 +78,14 @@ def client(server): ...@@ -60,11 +78,14 @@ def client(server):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras # first test base model, then test loras, then test prompt adapters
"model_name", "model_name,num_virtual_tokens",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"], [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
) )
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
num_virtual_tokens: int):
completion = await client.completions.create(model=model_name, completion = await client.completions.create(model=model_name,
prompt="Hello, my name is", prompt="Hello, my name is",
max_tokens=5, max_tokens=5,
...@@ -77,28 +98,30 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): ...@@ -77,28 +98,30 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
assert len(choice.text) >= 5 assert len(choice.text) >= 5
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage( assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11) completion_tokens=5,
prompt_tokens=6 + num_virtual_tokens,
total_tokens=11 + num_virtual_tokens)
# test using token IDs # test using token IDs
completion = await client.completions.create( completion = await client.completions.create(
model=MODEL_NAME, model=model_name,
prompt=[0, 0, 0, 0, 0], prompt=[0, 0, 0, 0, 0],
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
) )
assert len(completion.choices[0].text) >= 5 assert len(completion.choices[0].text) >= 1
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras # first test base model, then test loras, then test prompt adapters
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"], [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
) )
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs # test using token IDs
completion = await client.completions.create( completion = await client.completions.create(
model=MODEL_NAME, model=model_name,
prompt=[0, 0, 0, 0, 0], prompt=[0, 0, 0, 0, 0],
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
...@@ -110,14 +133,14 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): ...@@ -110,14 +133,14 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# just test 1 lora hereafter # just test 1 lora and 1 pa hereafter
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora"], [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
) )
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs # test using token IDs
completion = await client.completions.create( completion = await client.completions.create(
model=MODEL_NAME, model=model_name,
prompt=[0, 0, 0, 0, 0], prompt=[0, 0, 0, 0, 0],
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
...@@ -133,12 +156,12 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): ...@@ -133,12 +156,12 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora"], [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
) )
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs # test using token IDs
completion = await client.completions.create( completion = await client.completions.create(
model=MODEL_NAME, model=model_name,
prompt=[0, 0, 0, 0, 0], prompt=[0, 0, 0, 0, 0],
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
...@@ -154,7 +177,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): ...@@ -154,7 +177,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora"], [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
) )
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
model_name: str): model_name: str):
...@@ -162,7 +185,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, ...@@ -162,7 +185,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
with pytest.raises( with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs (openai.BadRequestError, openai.APIError)): # test using token IDs
await client.completions.create( await client.completions.create(
model=MODEL_NAME, model=model_name,
prompt=[0, 0, 0, 0, 0], prompt=[0, 0, 0, 0, 0],
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
...@@ -174,7 +197,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, ...@@ -174,7 +197,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
with pytest.raises( with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs (openai.BadRequestError, openai.APIError)): # test using token IDs
stream = await client.completions.create( stream = await client.completions.create(
model=MODEL_NAME, model=model_name,
prompt=[0, 0, 0, 0, 0], prompt=[0, 0, 0, 0, 0],
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
...@@ -199,7 +222,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, ...@@ -199,7 +222,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora"], [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
) )
async def test_completion_streaming(client: openai.AsyncOpenAI, async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str): model_name: str):
...@@ -233,7 +256,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, ...@@ -233,7 +256,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", "model_name",
["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"], [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
) )
async def test_completion_stream_options(client: openai.AsyncOpenAI, async def test_completion_stream_options(client: openai.AsyncOpenAI,
model_name: str): model_name: str):
...@@ -369,9 +392,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, ...@@ -369,9 +392,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora"], [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
) )
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
# test both text and token IDs # test both text and token IDs
...@@ -623,7 +645,7 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, ...@@ -623,7 +645,7 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
) )
async def test_tokenize(client: openai.AsyncOpenAI, model_name: str): async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
base_url = str(client.base_url)[:-3].strip("/") base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast") tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
for add_special in [False, True]: for add_special in [False, True]:
prompt = "This is a test prompt." prompt = "This is a test prompt."
...@@ -650,7 +672,7 @@ async def test_tokenize(client: openai.AsyncOpenAI, model_name: str): ...@@ -650,7 +672,7 @@ async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
) )
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str): async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
base_url = str(client.base_url)[:-3] base_url = str(client.base_url)[:-3]
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast") tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
prompt = "This is a test prompt." prompt = "This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=False) tokens = tokenizer.encode(prompt, add_special_tokens=False)
......
import json import json
import pathlib
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
...@@ -74,8 +75,8 @@ class OpenAIServing: ...@@ -74,8 +75,8 @@ class OpenAIServing:
self.prompt_adapter_requests = [] self.prompt_adapter_requests = []
if prompt_adapters is not None: if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1): for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with open(f"./{prompt_adapter.local_path}" with pathlib.Path(prompt_adapter.local_path,
f"/adapter_config.json") as f: "adapter_config.json").open() as f:
adapter_config = json.load(f) adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"] num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append( self.prompt_adapter_requests.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment