Unverified Commit 5ae5ed1e authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Consolidate prompt arguments to LLM engines (#4328)


Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
parent 290f4ada
...@@ -63,9 +63,9 @@ steps: ...@@ -63,9 +63,9 @@ steps:
mirror_hardwares: [amd] mirror_hardwares: [amd]
commands: commands:
# these tests have to be separated, because each one will allocate all posible GPU memory - pytest -v -s test_inputs.py
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py - pytest -v -s entrypoints -m llm
- pytest -v -s entrypoints/test_server_oot_registration.py - pytest -v -s entrypoints -m openai
- label: Examples Test - label: Examples Test
working_dir: "/vllm-workspace/examples" working_dir: "/vllm-workspace/examples"
...@@ -110,6 +110,9 @@ steps: ...@@ -110,6 +110,9 @@ steps:
mirror_hardwares: [amd] mirror_hardwares: [amd]
command: pytest -v -s test_logits_processor.py command: pytest -v -s test_logits_processor.py
- label: Utils Test
command: pytest -v -s test_utils.py
- label: Worker Test - label: Worker Test
mirror_hardwares: [amd] mirror_hardwares: [amd]
command: pytest -v -s worker command: pytest -v -s worker
......
...@@ -3,13 +3,14 @@ import argparse ...@@ -3,13 +3,14 @@ import argparse
import json import json
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
...@@ -48,7 +49,9 @@ def main(args: argparse.Namespace): ...@@ -48,7 +49,9 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size, size=(args.batch_size,
args.input_len)) args.input_len))
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() dummy_inputs: List[PromptStrictInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
def run_to_completion(profile_dir: Optional[str] = None): def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir: if profile_dir:
...@@ -59,13 +62,13 @@ def main(args: argparse.Namespace): ...@@ -59,13 +62,13 @@ def main(args: argparse.Namespace):
], ],
on_trace_ready=torch.profiler.tensorboard_trace_handler( on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p: str(profile_dir))) as p:
llm.generate(prompt_token_ids=dummy_prompt_token_ids, llm.generate(dummy_inputs,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=False) use_tqdm=False)
print(p.key_averages()) print(p.key_averages())
else: else:
start_time = time.perf_counter() start_time = time.perf_counter()
llm.generate(prompt_token_ids=dummy_prompt_token_ids, llm.generate(dummy_inputs,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=False) use_tqdm=False)
end_time = time.perf_counter() end_time = time.perf_counter()
......
LLM Class LLM Class
========== =========
.. autoclass:: vllm.LLM .. autoclass:: vllm.LLM
:members: :members:
......
LLM Inputs
==========
.. autodata:: vllm.inputs.PromptStrictInputs
.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
:members:
:member-order: bysource
.. autoclass:: vllm.inputs.TokensPrompt
:show-inheritance:
:members:
:member-order: bysource
Offline Inference
=================================
.. toctree::
:maxdepth: 1
llm
llm_inputs
...@@ -68,13 +68,6 @@ Documentation ...@@ -68,13 +68,6 @@ Documentation
getting_started/quickstart getting_started/quickstart
getting_started/examples/examples_index getting_started/examples/examples_index
.. toctree::
:maxdepth: 1
:caption: Offline Inference
offline_inference/llm
offline_inference/sampling_params
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
:caption: Serving :caption: Serving
...@@ -109,6 +102,8 @@ Documentation ...@@ -109,6 +102,8 @@ Documentation
:maxdepth: 2 :maxdepth: 2
:caption: Developer Documentation :caption: Developer Documentation
dev/sampling_params
dev/offline_inference/offline_index
dev/engine/engine_index dev/engine/engine_index
dev/kernel/paged_attention dev/kernel/paged_attention
dev/dockerfile/dockerfile dev/dockerfile/dockerfile
......
...@@ -48,7 +48,7 @@ completion = client.chat.completions.create( ...@@ -48,7 +48,7 @@ completion = client.chat.completions.create(
``` ```
### Extra Parameters for Chat API ### Extra Parameters for Chat API
The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python :language: python
...@@ -65,7 +65,7 @@ The following extra parameters are supported: ...@@ -65,7 +65,7 @@ The following extra parameters are supported:
``` ```
### Extra Parameters for Completions API ### Extra Parameters for Completions API
The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python :language: python
......
...@@ -23,11 +23,15 @@ def run_llava_pixel_values(): ...@@ -23,11 +23,15 @@ def run_llava_pixel_values():
"\nUSER: What is the content of this image?\nASSISTANT:") "\nUSER: What is the content of this image?\nASSISTANT:")
# This should be provided by another online or offline component. # This should be provided by another online or offline component.
images = torch.load("images/stop_sign_pixel_values.pt") image = torch.load("images/stop_sign_pixel_values.pt")
outputs = llm.generate({
"prompt":
prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
})
outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
for o in outputs: for o in outputs:
generated_text = o.outputs[0].text generated_text = o.outputs[0].text
print(generated_text) print(generated_text)
...@@ -46,11 +50,14 @@ def run_llava_image_features(): ...@@ -46,11 +50,14 @@ def run_llava_image_features():
"\nUSER: What is the content of this image?\nASSISTANT:") "\nUSER: What is the content of this image?\nASSISTANT:")
# This should be provided by another online or offline component. # This should be provided by another online or offline component.
images = torch.load("images/stop_sign_image_features.pt") image = torch.load("images/stop_sign_image_features.pt")
outputs = llm.generate(prompt, outputs = llm.generate({
multi_modal_data=MultiModalData( "prompt":
type=MultiModalData.Type.IMAGE, data=images)) prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
})
for o in outputs: for o in outputs:
generated_text = o.outputs[0].text generated_text = o.outputs[0].text
print(generated_text) print(generated_text)
......
...@@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" ...@@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
[tool.isort] [tool.isort]
use_parentheses = true use_parentheses = true
skip_gitignore = true skip_gitignore = true
[tool.pytest.ini_options]
markers = [
"skip_global_cleanup",
"llm: run tests for vLLM API only",
"openai: run tests for OpenAI API only",
]
...@@ -25,7 +25,7 @@ class MockEngine: ...@@ -25,7 +25,7 @@ class MockEngine:
return [RequestOutput( return [RequestOutput(
request_id=self.request_id)] if self.request_id else [] request_id=self.request_id)] if self.request_id else []
async def encode_request_async(self, *args, **kwargs): async def process_model_inputs_async(self, *args, **kwargs):
pass pass
def generate(self, request_id): def generate(self, request_id):
......
...@@ -29,7 +29,7 @@ def server(): ...@@ -29,7 +29,7 @@ def server():
ray.shutdown() ray.shutdown()
@pytest.fixture(scope="session") @pytest.fixture(scope="module")
def client(): def client():
client = openai.AsyncOpenAI( client = openai.AsyncOpenAI(
base_url="http://localhost:8000/v1", base_url="http://localhost:8000/v1",
......
...@@ -12,6 +12,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer, ...@@ -12,6 +12,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel from vllm.distributed import destroy_model_parallel
from vllm.inputs import PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import MultiModalData from vllm.sequence import MultiModalData
...@@ -402,12 +403,22 @@ class VllmRunner: ...@@ -402,12 +403,22 @@ class VllmRunner:
) -> List[Tuple[List[int], str]]: ) -> List[Tuple[List[int], str]]:
if images is not None: if images is not None:
assert len(prompts) == images.shape[0] assert len(prompts) == images.shape[0]
req_outputs = self.model.generate(
prompts, prompt_inputs: List[PromptInputs] = []
sampling_params=sampling_params, for i, prompt in enumerate(prompts):
multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE, image = None if images is None else images[i:i + 1]
data=images) mm_data = None if image is None else MultiModalData(
if images is not None else None) type=MultiModalData.Type.IMAGE,
data=image,
)
prompt_inputs.append({
"prompt": prompt,
"multi_modal_data": mm_data,
})
req_outputs = self.model.generate(prompt_inputs,
sampling_params=sampling_params)
outputs = [] outputs = []
for req_output in req_outputs: for req_output in req_outputs:
prompt_str = req_output.prompt prompt_str = req_output.prompt
......
...@@ -133,8 +133,11 @@ def test_append_slot_cow(): ...@@ -133,8 +133,11 @@ def test_append_slot_cow():
# Allocate prompt to gpu block. There is one slot left in the block. # Allocate prompt to gpu block. There is one slot left in the block.
prompt = Sequence(seq_id=1, prompt = Sequence(seq_id=1,
prompt="one two three", inputs={
prompt_token_ids=[1, 2, 3], "prompt": "one two three",
"prompt_token_ids": [1, 2, 3],
"multi_modal_data": None
},
block_size=block_size) block_size=block_size)
# Fork the sequence, such that a COW will be required when we append a new # Fork the sequence, such that a COW will be required when we append a new
...@@ -304,7 +307,13 @@ def test_sliding_window_multi_seq(): ...@@ -304,7 +307,13 @@ def test_sliding_window_multi_seq():
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
parent = Sequence(1, "one two three", [0, 1, 2], block_size) parent = Sequence(seq_id=1,
inputs={
"prompt": "one two three",
"prompt_token_ids": [0, 1, 2],
"multi_modal_data": None
},
block_size=block_size)
seq_group = SequenceGroup(request_id="1", seq_group = SequenceGroup(request_id="1",
seqs=[parent], seqs=[parent],
arrival_time=time.time(), arrival_time=time.time(),
......
...@@ -21,7 +21,13 @@ def create_dummy_prompt( ...@@ -21,7 +21,13 @@ def create_dummy_prompt(
# and prompt "0 ... block_size". # and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length)) prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) prompt = Sequence(int(request_id),
inputs={
"prompt": prompt_str,
"prompt_token_ids": prompt_tokens,
"multi_modal_data": None,
},
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id, seq_group = SequenceGroup(request_id=request_id,
seqs=[prompt], seqs=[prompt],
arrival_time=time.time(), arrival_time=time.time(),
...@@ -51,8 +57,11 @@ def create_seq_group( ...@@ -51,8 +57,11 @@ def create_seq_group(
for seq_id_offset, output_len in enumerate(seq_output_lens): for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence( seq = Sequence(
seq_id=seq_id_start + seq_id_offset, seq_id=seq_id_start + seq_id_offset,
prompt="", inputs={
prompt_token_ids=prompt_token_ids, "prompt": "",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16, block_size=16,
) )
......
...@@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str): ...@@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str):
with pytest.raises(ValueError) as err: with pytest.raises(ValueError) as err:
llm.generate("abc", sampling_params) llm.generate("abc", sampling_params)
assert "prompts must be None if" in str(err.value) assert "prompts must be None if" in str(err.value)
outputs = llm.generate(prompt_token_ids=[[1, 2, 3]], outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
sampling_params=sampling_params) sampling_params=sampling_params)
assert len(outputs) > 0 assert len(outputs) > 0
completions = outputs[0].outputs completions = outputs[0].outputs
......
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
import pytest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}" CHAT_TEMPLATE = "Dummy chat template for testing {}"
pytestmark = pytest.mark.openai
@dataclass @dataclass
class MockModelConfig: class MockModelConfig:
......
...@@ -52,6 +52,8 @@ TEST_SCHEMA = { ...@@ -52,6 +52,8 @@ TEST_SCHEMA = {
TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
pytestmark = pytest.mark.openai
def test_guided_logits_processors(): def test_guided_logits_processors():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
......
import weakref
from typing import List
import pytest
from vllm import LLM, EmbeddingRequestOutput, PoolingParams
from ..conftest import cleanup
MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
PROMPTS = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
TOKEN_IDS = [
# Using ID={0, 1, 2, 3} results in NaN values,
# so we add this offset of 1000
[1000],
[1000, 1001],
[1000, 1002, 1001],
[1000, 1003, 1001, 1002],
]
pytestmark = pytest.mark.llm
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
max_num_batched_tokens=32768,
tensor_parallel_size=1,
gpu_memory_utilization=0.75,
enforce_eager=True)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup()
def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
o2: List[EmbeddingRequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)
v2_output = llm.encode(prompt, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
prompt_token_ids):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.encode(prompt_token_ids=prompt_token_ids,
pooling_params=pooling_params)
v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)
v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.encode(
[{
"prompt": p
} for p in PROMPTS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.encode(prompt_token_ids=TOKEN_IDS,
pooling_params=pooling_params)
v2_output = llm.encode(
[{
"prompt_token_ids": p
} for p in TOKEN_IDS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_multiple_pooling_params(llm: LLM):
pooling_params = [
PoolingParams(),
PoolingParams(),
PoolingParams(),
PoolingParams(),
]
# Multiple PoolingParams should be matched with each prompt
outputs = llm.encode(PROMPTS, pooling_params=pooling_params)
assert len(PROMPTS) == len(outputs)
# Exception raised, if the size of params does not match the size of prompts
with pytest.raises(ValueError):
outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3])
# Single PoolingParams should be applied to every prompt
single_pooling_params = PoolingParams()
outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params)
assert len(PROMPTS) == len(outputs)
# pooling_params is None, default params should be applied
outputs = llm.encode(PROMPTS, pooling_params=None)
assert len(PROMPTS) == len(outputs)
import pytest import weakref
from typing import List
from vllm import LLM, SamplingParams import pytest
from vllm import LLM, RequestOutput, SamplingParams
def test_multiple_sampling_params(): from ..conftest import cleanup
llm = LLM(model="facebook/opt-125m", MODEL_NAME = "facebook/opt-125m"
max_num_batched_tokens=4096,
tensor_parallel_size=1)
prompts = [ PROMPTS = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
"The capital of France is", "The capital of France is",
"The future of AI is", "The future of AI is",
] ]
TOKEN_IDS = [
[0],
[0, 1],
[0, 2, 1],
[0, 3, 1, 2],
]
pytestmark = pytest.mark.llm
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
max_num_batched_tokens=4096,
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup()
def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=prompt,
sampling_params=sampling_params)
v2_output = llm.generate(prompt, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.generate({"prompt": prompt},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
prompt_token_ids):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.generate(prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params)
v2_output = llm.generate({"prompt_token_ids": prompt_token_ids},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=PROMPTS,
sampling_params=sampling_params)
v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.generate(
[{
"prompt": p
} for p in PROMPTS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.generate(prompt_token_ids=TOKEN_IDS,
sampling_params=sampling_params)
v2_output = llm.generate(
[{
"prompt_token_ids": p
} for p in TOKEN_IDS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_multiple_sampling_params(llm: LLM):
sampling_params = [ sampling_params = [
SamplingParams(temperature=0.01, top_p=0.95), SamplingParams(temperature=0.01, top_p=0.95),
SamplingParams(temperature=0.3, top_p=0.95), SamplingParams(temperature=0.3, top_p=0.95),
...@@ -24,18 +127,18 @@ def test_multiple_sampling_params(): ...@@ -24,18 +127,18 @@ def test_multiple_sampling_params():
] ]
# Multiple SamplingParams should be matched with each prompt # Multiple SamplingParams should be matched with each prompt
outputs = llm.generate(prompts, sampling_params=sampling_params) outputs = llm.generate(PROMPTS, sampling_params=sampling_params)
assert len(prompts) == len(outputs) assert len(PROMPTS) == len(outputs)
# Exception raised, if the size of params does not match the size of prompts # Exception raised, if the size of params does not match the size of prompts
with pytest.raises(ValueError): with pytest.raises(ValueError):
outputs = llm.generate(prompts, sampling_params=sampling_params[:3]) outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3])
# Single SamplingParams should be applied to every prompt # Single SamplingParams should be applied to every prompt
single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95) single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
outputs = llm.generate(prompts, sampling_params=single_sampling_params) outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params)
assert len(prompts) == len(outputs) assert len(PROMPTS) == len(outputs)
# sampling_params is None, default params should be applied # sampling_params is None, default params should be applied
outputs = llm.generate(prompts, sampling_params=None) outputs = llm.generate(PROMPTS, sampling_params=None)
assert len(prompts) == len(outputs) assert len(PROMPTS) == len(outputs)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment