Commit 539aa992 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
from typing import List, Optional, Tuple, Type, overload
import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from ....utils import multi_gpu_test
from ...utils import check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT = 1
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|image|><|begin_of_text|>The meaning of the image is",
"cherry_blossom":
"<|image|><|begin_of_text|>The city is",
})
text_only_prompts = [
"The color of the sky is blue but sometimes it can also be",
]
models = [
"meta-llama/Llama-3.2-11B-Vision-Instruct",
]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
assert output_str[0] == " "
hf_output_str = output_str[1:]
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
images = [asset.pil_image for asset in image_assets]
if size_factors is not None:
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
elif sizes is not None:
inputs_per_image = [(
[
prompt if size is not None else text_only_prompts[0]
for size in sizes
],
[
image.resize(size) if size is not None else None
for size in sizes
],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
if len(sizes) == 0:
inputs_per_image.append(
(text_only_prompts, [None] * len(text_only_prompts)))
else:
raise ValueError("You must provide either `size_factors` or `sizes`")
_run_test(hf_runner,
vllm_runner,
inputs_per_image,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend)
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test are from IMAGE_ASSETS.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_num_seqs=16,
max_model_len=4096,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
}) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs
]
def process(hf_inputs: BatchEncoding):
return hf_inputs
from transformers import AutoConfig
from transformers.models.mllama import MllamaConfig as MllamaConfigHf
# use transformer's MllamaConfig for hf_runner
# and vllm's MllamaConfig for vllm_runner
AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True)
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs
]
from vllm.transformers_utils.configs.mllama import MllamaConfig
AutoConfig.register("mllama", MllamaConfig, exist_ok=True)
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[
# Text only
[],
# Single-size
[(512, 512)],
# Single-size, batched
[(512, 512), (512, 512), (512, 512)],
# Multi-size, batched
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028)],
# Multi-size, batched, including text only
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028), None],
# mllama has 8 possible aspect ratios, carefully set the sizes
# to cover all of them
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
max_tokens, num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028), None],
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes,
dtype, max_tokens, num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=2,
)
......@@ -6,7 +6,8 @@ from vllm.model_executor.models import _MODELS, ModelRegistry
@pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls):
if (model_cls == "Qwen2VLForConditionalGeneration"
if (model_cls in ("LlavaOnevisionForConditionalGeneration",
"Qwen2VLForConditionalGeneration")
and transformers.__version__ < "4.45"):
pytest.skip("Waiting for next transformers release")
......
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import Logprob, SampleLogprobs
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
TokensText = Tuple[List[int], str]
......@@ -34,20 +36,47 @@ def check_outputs_equal(
assert output_ids_0 == output_ids_1, fail_msg
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * List of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
float]],
SampleLogprobs]]]
# Allow for tokens to be represented as str's rather than IDs
# Allow for tokens to be represented as str's rather than IDs;
# tuple of
# * Token string representations list
# * String
# * Optional list of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
List[Dict[str,
Logprob]]]]]
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * Optional list of top sample logprobs for each sampled token
# * Optional list of top prompt logprobs for each prompt token
#
# Allows prompt logprobs to be requested.
TokensTextLogprobsPromptLogprobs = Tuple[
List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]
def check_logprobs_close(
*,
outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
outputs_0_lst: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
TextTextLogprobs]],
outputs_1_lst: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
TextTextLogprobs]],
name_0: str,
name_1: str,
num_outputs_0_skip_tokens: int = 0,
......@@ -57,6 +86,18 @@ def check_logprobs_close(
"""Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
How sample logprobs are compared:
* `always_check_logprobs == True`: set of highest-logprob token ids
must match between seq0 and seq1 at all sampled token offsets
* `always_check_logprobs == False`: highest-logprob token ids are
only compared at sampled token offsets for which generated token
ids don't match
Prompt logprobs must be provided either for both input sequences, or
for neither. If prompt logprobs are provided, then highest-logprob
prompt token ids must match between seq0 and seq1 at all prompt token
offsets.
Args:
outputs_0_lst: First sequence to compare
outputs_0_lst: Second sequence to compare
......@@ -78,8 +119,65 @@ def check_logprobs_close(
for prompt_idx, (outputs_0,
outputs_1) in enumerate(zip(outputs_0_lst,
outputs_1_lst)):
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
assert len(outputs_0) == len(outputs_1)
if len(outputs_0) == 3:
assert len(outputs_1) == 3
# Break out tokens, text & sample logprobs
# (prompt logprobs were not provided)
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
elif len(outputs_0) == 4:
assert len(outputs_1) == 4
# Break out tokens, text, sample logprobs & prompt logprobs
(
output_ids_0,
output_str_0,
logprobs_0,
prompt_logprobs_0,
) = outputs_0
(
output_ids_1,
output_str_1,
logprobs_1,
prompt_logprobs_1,
) = outputs_1
# Test prompt logprobs closeness
if (prompt_logprobs_0 is not None
and prompt_logprobs_1 is not None):
# Both sequences' prompt logprobs lists are not `None``
# (although individual list elements may be `None`);
# for each token's logprobs:
for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
zip(prompt_logprobs_0, prompt_logprobs_1)):
fail_msg = (
f"Prompt logprobs test:"
f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
if logprobs_elem_0 is None:
# If the seq 0 token's logprobs are `None`,
# the seq 1 token's logprobs must be `None`
assert logprobs_elem_1 is None, fail_msg
else:
# If the seq 0 token's logprobs are not `None`,
# the seq 1 token's logprobs must not be `None`
assert logprobs_elem_1 is not None, fail_msg
# Logprobs check: top-k token choices must be the same
assert (set(logprobs_elem_0.keys()) == set(
logprobs_elem_1.keys())), fail_msg
else:
# Both sequence logprobs lists must be `None`
fail_msg = (f"Prompt logprobs test:"
f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
assert (prompt_logprobs_0 is None
and prompt_logprobs_1 is None), fail_msg
else:
raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
f"{len(outputs_0)} elements were provided: "
f"{outputs_0}")
if logprobs_0 is None:
logprobs_0 = [None] * len(output_ids_0)
......@@ -144,3 +242,36 @@ def check_logprobs_close(
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)
def build_model_context(model_name: str,
tokenizer_name: Optional[str] = None,
trust_remote_code: bool = False,
mm_processor_kwargs: Optional[Dict] = None,
limit_mm_per_prompt: Optional[Dict] = None):
"""Creates an InputContext for a given model.
Args:
model_name: Name of the model being considered.
tokenizer_name: Name of the tokenizer being considered.
trust_remote_code: Whether or not to allow loading remote code.
mm_processor_kwargs: optional processor kwargs for to be leveraged
in the input processor, mapper, dummy data creation, etc.
limit_mm_per_prompt: Multimodal limits.
Returns:
InputContext for the model being considered.
"""
if tokenizer_name is None:
tokenizer_name = model_name
model_config = ModelConfig(
model_name,
tokenizer_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
dtype="float32",
seed=0,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt=limit_mm_per_prompt,
)
return InputContext(model_config)
"""Test that aborting is handled properly."""
import asyncio
import tempfile
import uuid
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
EXPECTED_TOKENS = 250
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
@pytest.mark.asyncio
async def test_abort(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
request_id_to_be_aborted = "request-aborted"
request_ids_a = [f"request-a-{idx}" for idx in range(10)]
request_ids_b = [f"request-b-{idx}" for idx in range(10)]
# Requests started before one to be aborted.
tasks = []
for request_id in request_ids_a:
tasks.append(
asyncio.create_task(
generate(client, request_id, EXPECTED_TOKENS)))
# Aborted.
task_aborted = asyncio.create_task(
generate(client, request_id_to_be_aborted, EXPECTED_TOKENS))
# Requests started after one to be aborted.
for request_id in request_ids_b:
tasks.append(
asyncio.create_task(
generate(client, request_id, EXPECTED_TOKENS)))
# Actually abort.
await asyncio.sleep(0.5)
await client.abort(request_id_to_be_aborted)
# Confirm that we got all the EXPECTED tokens from the requests.
for task in tasks:
count, request_id = await task
assert count == EXPECTED_TOKENS, (
f"{request_id} generated only {count} tokens")
# Cancel task (this will hang indefinitely if not).
task_aborted.cancel()
# Shutdown.
client.close()
"""Test that various errors are handled properly."""
import asyncio
import tempfile
import time
import uuid
from unittest.mock import Mock
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.lora.request import LoRARequest
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during first forward pass.
engine.engine.model_executor.execute_model = Mock(
side_effect=RAISED_ERROR(RAISED_VALUE))
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_evil_forward(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_forward) as engine:
client = await engine.make_client()
# Server should be healthy after initial probe.
await asyncio.sleep(2.0)
await client.check_health()
# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored
# Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored
await asyncio.sleep(1.0)
with pytest.raises(RAISED_ERROR):
await client.check_health()
assert client.errored
# Shutdown.
client.close()
def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during first forward pass.
engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR)
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_failed_health_check(tmp_socket):
with RemoteMQLLMEngine(
engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_model_executor_health) as engine:
client = await engine.make_client()
assert client.is_running
# Health probe should throw RAISED_ERROR.
await asyncio.sleep(15.)
with pytest.raises(RAISED_ERROR):
await client.check_health()
assert client.errored
# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
client.close()
def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during abort call.
engine.engine.abort_request = Mock(side_effect=RAISED_ERROR)
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_failed_abort(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_abort) as engine:
client = await engine.make_client()
assert client.is_running
# Firsh check health should work.
await client.check_health()
# Trigger an abort on the client side.
# This request ID does not exist, and will cause the engine to error
await client.abort(request_id="foo")
# Future generation requests will now fail
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id=uuid.uuid4()):
pass
assert "KeyError" in repr(execinfo.value)
assert client.errored
# This should raise the original error.
with pytest.raises(RAISED_ERROR):
await client.check_health()
client.close()
@pytest.mark.asyncio
async def test_bad_request(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
"invalid-lora", 1,
"invalid-path")):
pass
# This request should be okay.
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass
# Shutdown.
client.close()
@pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch):
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
# When LLMEngine is loaded, it will crash.
def mock_init():
raise ValueError
monkeypatch.setattr(LLMEngine, "__init__", mock_init)
start = time.perf_counter()
async with build_async_engine_client(args):
pass
end = time.perf_counter()
assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s "
"if there is an error in the startup.")
@pytest.mark.asyncio
async def test_mp_cuda_init():
# it should not crash, when cuda is initialized
# in the API server process
import torch
torch.cuda.init()
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
async with build_async_engine_client(args):
pass
"""Test that the MQLLMEngine is able to handle 10k concurrent requests."""
import asyncio
import tempfile
import uuid
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
MODEL = "google/gemma-1.1-2b-it"
NUM_EXPECTED_TOKENS = 10
NUM_REQUESTS = 10000
# Scenarios to test for num generated token.
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True)
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
@pytest.mark.asyncio
async def test_load(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
# Create concurrent requests.
tasks = []
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(client, request_id, NUM_EXPECTED_TOKENS)))
# Confirm that we got all the EXPECTED tokens from the requests.
failed_request_id = None
tokens = None
for task in tasks:
num_generated_tokens, request_id = await task
if (num_generated_tokens != NUM_EXPECTED_TOKENS
and failed_request_id is None):
failed_request_id = request_id
tokens = num_generated_tokens
assert failed_request_id is None, (
f"{failed_request_id} generated {tokens} but "
f"expected {NUM_EXPECTED_TOKENS}")
# Shutdown.
client.close()
import asyncio
import multiprocessing
from typing import Callable, Tuple, Union
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
async def generate(
client: MQLLMEngineClient,
request_id: str,
num_tokens: int,
return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]:
final_output = None
count = 0
async for out in client.generate(
request_id=request_id,
inputs="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)):
count += 1
final_output = out
await asyncio.sleep(0.)
if return_output:
return final_output
# Confirm we generated all the tokens we expected.
return count, request_id
def run_normal(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Run engine.
engine.start()
class RemoteMQLLMEngine:
def __init__(self,
engine_args: AsyncEngineArgs,
ipc_path: str,
run_fn: Callable = run_normal) -> None:
self.engine_args = engine_args
self.ipc_path = ipc_path
context = multiprocessing.get_context("spawn")
self.proc = context.Process(target=run_fn,
args=(engine_args, ipc_path))
self.proc.start()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.proc.kill()
async def make_client(self) -> MQLLMEngineClient:
engine_config = self.engine_args.create_engine_config()
client = MQLLMEngineClient(self.ipc_path, engine_config)
while True:
try:
await client.setup()
break
except TimeoutError:
assert self.proc.is_alive()
return client
......@@ -100,3 +100,95 @@ def test_multi_step_llm(
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
def test_multi_step_llm_w_prompt_logprobs(
vllm_runner,
example_prompts,
model: str,
dtype: str,
tp_size: int,
max_tokens: int,
enforce_eager: int,
num_scheduler_steps: int,
num_prompts: int,
num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int],
) -> None:
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine.
Set up a vLLM engine instance w/ single-step scheduling as a ground-truth
reference.
Prompt them with the same example prompts.
Validate:
* All generated logprobs are all very close
Args:
hf_runner: HF transformers model runner fixture
vllm_runner: vLLM model runner fixture
example_prompts: test fixture providing example prompts
model: model under test (same for single- and multi-step engines)
dtype: tensor datatype for engine to utilize
tp_size: degree of tensor-parallelism
max_tokens: the maximum number of tokens to generate
enforce_eager
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
GPU -> CPU output transfer
num_prompts: number of example prompts under test
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
num_prompt_logprobs: number of logprobs to return for each prompt token;
note that this argument is not supported by the
OpenAI completions endpoint.
"""
prompts = example_prompts
if len(prompts) < num_prompts:
prompts = prompts * ((num_prompts // len(prompts)) + 1)
prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts
with vllm_runner(
model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
prompts,
max_tokens,
num_logprobs,
num_prompt_logprobs=num_prompt_logprobs)
with vllm_runner(
model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
) as vllm_model:
single_step_vllm_outputs = vllm_model.generate_greedy_logprobs(
prompts,
max_tokens,
num_logprobs,
num_prompt_logprobs=num_prompt_logprobs)
check_logprobs_close(
outputs_0_lst=single_step_vllm_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
......@@ -5,7 +5,7 @@ from vllm.multimodal.base import MultiModalInputs, NestedTensors
def assert_nested_tensors_equal(expected: NestedTensors,
actual: NestedTensors):
assert type(expected) == type(actual)
assert type(expected) == type(actual) # noqa: E721
if isinstance(expected, torch.Tensor):
assert torch.equal(expected, actual)
else:
......
from array import array
from typing import Mapping
from unittest.mock import patch
import pytest
import torch
from vllm.inputs import InputContext, LLMInputs
from vllm.inputs.registry import InputRegistry
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from ..models.utils import build_model_context
# Used for fast tests where the model doesn't matter
DUMMY_MODEL_ID = "facebook/opt-125m"
# Used for tests that need a multimodal model
MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
# For mm_processor_kwargs - we test overrides by defining mocks for each place
# it is used, and ensuring that we can pass processor kwargs an override value
# to receive the intended result for things like sequence length etc.
DEFAULT_NUM_CROPS = 4
NUM_CROPS_OVERRIDE = 16
# Mocks for all of the places that we use the mm_processor_kwargs
# to override values in different callables
@pytest.fixture
def use_processor_mock():
"""Patches the internal model input processor with an override callable."""
def custom_processor(ctx: InputContext,
llm_inputs: LLMInputs,
*,
num_crops=DEFAULT_NUM_CROPS):
# For testing purposes, we don't worry about the llm inputs / return
# type validation, and just return the value of the kwarg that we
# clobber.
return num_crops
with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
return_value=custom_processor):
yield
@pytest.fixture
def use_dummy_data_mock():
"""Patches the internal model input processor with an override callable."""
def custom_dummy_data_factory(self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
num_crops=DEFAULT_NUM_CROPS):
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
return seq_data, None
with patch(
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
custom_dummy_data_factory):
yield
# Lazy import to avoid CUDA reinitialization error
def mm_model_cls():
from vllm.model_executor.models.phi3v import Phi3VForCausalLM
return Phi3VForCausalLM
# lambda whose signature matches max token calcs extra & mapper + extra kwargs
get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
"num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
}
### Test for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
"""Ensure that by default, there is no processor override."""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID)
processor = dummy_registry.create_input_processor(ctx.model_config)
proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
proc_outputs = processor(inputs=proc_inputs)
assert proc_inputs is proc_outputs
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_processor_default_kwargs(use_processor_mock, num_crops):
"""Ensure input processors can use processor kwargs."""
dummy_registry = InputRegistry()
# If we have a value for num_crops, pass the override value and make
# sure we get that value as a return-value from out mock processor,
# otherwise fall back to the default value
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
assert num_crops_val == expected_num_crops
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_processor_with_sad_kwarg_overrides(use_processor_mock,
mm_processor_kwargs):
"""Ensure that input processors filter out invalid mm_processor_kwargs"""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
assert num_crops_val == DEFAULT_NUM_CROPS
### Test overrides for the dummy data
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
"""Ensure dummy data factories can use processor kwargs."""
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == expected_seq_count
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
mm_processor_kwargs):
"""Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
### Test overrides for the max token count per multimodal instance
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_max_tokens_kwarg_overrides(num_crops):
"""Ensure max token calcs can use processor kwargs."""
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
with patch.object(
mm_registry._get_plugin("image"),
"_max_mm_tokens",
{mm_model_cls(): get_num_crops},
):
max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
ctx.model_config)
assert expected_seq_count == max_multimodal_tokens
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
"""Ensure that max token calcs filters out invalid mm_processor_kwargs"""
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Similar before, but since these kwargs get filtered,
# we always get our default value back.
with patch.object(
mm_registry._get_plugin("image"),
"_max_mm_tokens",
{mm_model_cls(): get_num_crops},
):
max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
ctx.model_config)
assert max_multimodal_tokens == DEFAULT_NUM_CROPS
### Test overrides for the mapper
@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
"""Ensure that the mapper processor kwargs can fall back to HF models."""
# NOTE - we don't validate bad inputs for the default mapper, because it's
# through the automodel interface in transformers, so we can't easily
# inspect what kwargs are or are not allowed.
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs={"num_crops": num_crops},
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
image = image_assets[0].pil_image
mm_inputs = {"image": image}
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
# Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
"""Ensure custom mappers can use processor kwargs."""
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
image = image_assets[0].pil_image
mm_inputs = {"image": image}
with patch.object(
mm_registry._get_plugin("image"),
"_default_input_mapper",
{mm_model_cls(): custom_mapper},
):
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
mm_processor_kwargs):
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
image = image_assets[0].pil_image
mm_inputs = {"image": image}
with patch.object(
mm_registry._get_plugin("image"),
"_default_input_mapper",
{mm_model_cls(): custom_mapper},
):
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
......@@ -64,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason='Test requires at least 2 GPUs.')
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@fork_new_process_for_each_test
def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name, description) -> None:
hf_model_kwargs = {"load_in_4bit": True}
validate_generated_texts(hf_runner,
vllm_runner,
example_prompts[:1],
model_name,
hf_model_kwargs,
vllm_tp_size=2)
def log_generated_texts(prompts, outputs, runner_name):
logged_texts = []
for i, (_, generated_text) in enumerate(outputs):
......@@ -80,22 +98,21 @@ def validate_generated_texts(hf_runner,
vllm_runner,
prompts,
model_name,
hf_model_kwargs=None):
hf_model_kwargs=None,
vllm_tp_size=1):
# NOTE: run vLLM first, as it requires a clean process
# when using distributed inference
#Run with vLLM runner
with vllm_runner(model_name,
quantization='bitsandbytes',
load_format='bitsandbytes',
enforce_eager=True,
tensor_parallel_size=vllm_tp_size,
enforce_eager=False,
gpu_memory_utilization=0.8) as llm:
vllm_outputs = llm.generate_greedy(prompts, 8)
vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
# Clean up the GPU memory for the next test
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
......@@ -108,7 +125,6 @@ def validate_generated_texts(hf_runner,
hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
# Clean up the GPU memory for the next test
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
......
......@@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
assert attn._k_scale == 1.0
assert attn._v_scale == 1.0
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability >= 89 and not force_marlin:
if current_platform.has_device_capability(89) and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
......
......@@ -8,6 +8,8 @@ def is_quant_method_supported(quant_method: str) -> bool:
return False
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
return (capability >=
QUANTIZATION_METHODS[quant_method].get_min_capability())
assert capability is not None
min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()
return capability.to_int() >= min_capability
......@@ -9,9 +9,9 @@ import pytest
# 1. Increase max_tokens to 256.
# 2. Increase beam_width to 8.
# 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS = [128]
MAX_TOKENS = [64]
BEAM_WIDTHS = [4]
MODELS = ["facebook/opt-125m"]
MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
@pytest.mark.parametrize("model", MODELS)
......@@ -33,12 +33,19 @@ def test_beam_search_single_input(
max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
vllm_outputs = vllm_model.generate_beam_search_new(
example_prompts, beam_width, max_tokens)
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
vllm_output_ids, _ = vllm_outputs[i]
hf_output_ids, hf_output_texts = hf_outputs[i]
vllm_output_ids, vllm_output_texts = vllm_outputs[i]
for i, (hf_text,
vllm_text) in enumerate(zip(hf_output_texts,
vllm_output_texts)):
print(f">>>{i}-th hf output:")
print(hf_text)
print(f">>>{i}-th vllm output:")
print(vllm_text)
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
......
......@@ -42,18 +42,13 @@ def mock_causal_accepted_tensor(
@pytest.mark.parametrize(
"which_tokens_accepted",
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int,
disable_bonus_tokens: bool, device: str,
use_flashinfer: bool):
device: str, use_flashinfer: bool):
"""Verify the output has correct format given predetermined accepted matrix.
"""
if use_flashinfer and disable_bonus_tokens:
pytest.skip("Flashinfer rejection sampler must enable bonus token.")
set_random_seed(seed)
torch.set_default_device(device)
......@@ -88,9 +83,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
size=(batch_size, 1),
dtype=torch.int64)
rejection_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
......@@ -100,10 +93,6 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
)
expected_bonus_token_ids = bonus_token_ids.clone()
# If bonus tokens disabled. Verify they are set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
if disable_bonus_tokens:
expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1
if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
......@@ -143,8 +132,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str, use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
......@@ -177,8 +165,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded: float, n_rep: int, device: str,
use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
......@@ -251,8 +238,7 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
}
for use_flashinfer in [True, False]:
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
......@@ -282,8 +268,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
vocab_size = 30_000
torch.set_default_device(device)
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer,
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
strict_mode=True)
rejection_sampler.init_gpu_tensors(device=device)
......@@ -359,8 +344,7 @@ def test_rejection_sampling_approximates_target_distribution(
set_random_seed(seed)
helper = _CorrectnessTestHelper(
vocab_size=10,
rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer),
rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
)
draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
......
import itertools
import random
from array import array
from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
......@@ -12,8 +11,7 @@ import vllm.envs as envs
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import Counter, is_pin_memory_available
......@@ -59,9 +57,7 @@ def _do_sample(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
))
......@@ -205,9 +201,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
return sampling_params
def create_sequence_data(num_input=3, num_generated=0):
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE,
random.choices(range(0, VOCAB_SIZE), k=num_input)))
seq_data = SequenceData.from_seqs(
random.choices(range(0, VOCAB_SIZE), k=num_input))
if num_generated > 0:
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
k=num_generated)
......@@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
))
......@@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1,
top_k=top_k,
......@@ -699,11 +690,7 @@ def test_sampler_repetition_penalty_mixed(device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={
0:
SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
[1, 2, 3]))
},
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=sampling_params[i],
block_tables={0: [1]},
))
......
......@@ -55,14 +55,13 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
def get_acceptance_sampler(
posterior_threshold: float = 0.03,
posterior_alpha: float = 0.9,
disable_bonus_tokens: bool = False,
strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
"""
Initializes and returns a TypicalAcceptanceSampler.
"""
return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
disable_bonus_tokens, strict_mode)
strict_mode)
@pytest.mark.parametrize("k", list(range(1, 6)))
......@@ -154,11 +153,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_uniform_target_distribution_accepts_all_tokens(
seed: int, disable_bonus_tokens: bool, device: str):
seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a uniform target probability
distribution.
......@@ -166,17 +164,14 @@ def test_uniform_target_distribution_accepts_all_tokens(
This test verifies that when provided with a uniform target probability
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
entropy of the uniform target distribution being high should lead to all
draft tokens being accepted. The test also ensures that the behavior
regarding bonus tokens is consistent with the `disable_bonus_tokens`
flag.
draft tokens being accepted.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
......@@ -200,21 +195,15 @@ def test_uniform_target_distribution_accepts_all_tokens(
# should lead to all draft tokens being accepted. Verify that.
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
assert torch.all(output_token_ids[:, :k] == draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_temperature_zero_target_distribution(seed: int,
disable_bonus_tokens: bool,
device: str):
def test_temperature_zero_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a zero-temperature target
probability distribution.
......@@ -232,8 +221,7 @@ def test_temperature_zero_target_distribution(seed: int,
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
......@@ -267,11 +255,9 @@ def test_temperature_zero_target_distribution(seed: int,
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
device: str):
def test_mixed_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a mixed target probability
distribution.
......@@ -285,16 +271,13 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
with a probability of 1.0 is accepted, and all other tokens are rejected.
- For sequences with a uniform distribution, all draft tokens are
accepted.
- When `disable_bonus_tokens` is False, the bonus tokens are also accepted
for sequences with a uniform distribution.
"""
set_random_seed(seed)
k = 3
batch_size = 4
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
......@@ -328,21 +311,16 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
0]))
# For sequences 1 and 3 verify that all tokens are accepted since the
# target probability distribution is uniform. In addition verify that
# if disable_bonus_tokens is false then we also accept the bonus tokens.
# we also accept the bonus tokens.
assert torch.all(
output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
if disable_bonus_tokens:
assert torch.all(output_token_ids[[1, 3], -1] == -1)
else:
assert torch.all(output_token_ids[[1, 3], -1] != -1)
assert torch.all(output_token_ids[[1, 3], -1] != -1)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
device: str):
def test_accept_tokens_partially(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
tokens should be accepted.
......@@ -362,8 +340,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
......@@ -384,14 +361,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
# response we will expect the first 2 tokens to be the same as the
# draft tokens and the rest as -1
# draft tokens and the recovered token and rest as -1
draft_token_ids_to_replace = get_draft_token_ids(
batch_size, k, vocab_size, zero_temperature_token_ids)
draft_token_ids = torch.cat(
......@@ -404,16 +378,15 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
assert torch.all(
output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
assert torch.all(output_token_ids[:, -3:] == -1)
@pytest.mark.parametrize("seed", list(range(1)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_set_non_default_posteriors(seed: int,
disable_bonus_tokens: bool,
device: str):
def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with custom posterior thresholds and
alpha values. This test verifies that by modifying the posterior
......@@ -425,8 +398,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
......@@ -457,10 +429,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
# now accept even draft tokens with very low probability in the
# target distribution. Simulate and verify the same.
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True,
disable_bonus_tokens=disable_bonus_tokens,
posterior_threshold=0.0,
posterior_alpha=0.0)
strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
output_token_ids = typical_acceptance_sampler(
target_probs,
......@@ -470,25 +439,20 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
device: str):
def test_get_recovered_token_ids(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
This test verifies that the `_replacement_token_ids` method of the
This test verifies that the `_get_recovered_token_ids` method of the
TypicalAcceptanceSampler correctly identifies the token IDs to be used
as replacements based on the target probability distribution.
as recovered token IDs based on the target probability distribution.
Specifically, it ensures that the method correctly identifies the
tokens with the highest probability for each sequence in the batch.
"""
......@@ -497,14 +461,10 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = -torch.ones(
(batch_size, k), dtype=torch.long)
expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
dim=1)
expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
actual_replacement_tokens = (
typical_acceptance_sampler._replacement_token_ids(target_probs))
typical_acceptance_sampler._get_recovered_token_ids(target_probs))
assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
from itertools import cycle
from typing import List, Optional, Tuple
from typing import List, Optional, Sequence, Tuple, Union
import pytest
from vllm import LLM, SamplingParams
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import PromptLogprobs, SampleLogprobs
from ...conftest import cleanup
from ...models.utils import check_logprobs_close, check_outputs_equal
from ...models.utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
check_logprobs_close, check_outputs_equal)
from ...utils import RemoteOpenAIServer
PROMPTS = [
......@@ -81,45 +84,77 @@ def get_output_from_llm_generator(
return tokens, token_ids, acceptance_rate
def run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: Optional[int] = 0,
temperature: float = 0.0,
logprobs: int = 1):
org_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**baseline_llm_kwargs,
}
sd_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
logprobs=logprobs)
with vllm_runner(**org_args) as vllm_model:
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
with vllm_runner(**sd_args) as vllm_model:
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
check_logprobs_close(outputs_0_lst=org_outputs,
outputs_1_lst=sd_outputs,
name_0="org",
name_1="sd")
def check_logprobs_correctness(
spec_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
baseline_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
disable_logprobs: bool = False,
):
"""Compare sampled and prompt logprobs between baseline and spec decoding
"""
if not disable_logprobs:
return check_logprobs_close(
outputs_0_lst=baseline_outputs,
outputs_1_lst=spec_outputs,
name_0="org",
name_1="sd",
)
# Check correctness when disable_logprobs == True
for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
# Check generated token logprobs.
spec_logprobs = spec_output[2]
baseline_logprobs = baseline_output[2]
_check_logprobs_when_output_disabled(spec_logprobs,
baseline_logprobs,
is_prompt_logprobs=False)
# Check prompt logprobs too, if they exist
if len(baseline_output) == 4:
assert len(spec_output) == 4
spec_prompt_logprobs = spec_output[3]
baseline_prompt_logprobs = baseline_output[3]
_check_logprobs_when_output_disabled(spec_prompt_logprobs,
baseline_prompt_logprobs,
is_prompt_logprobs=True)
def _check_logprobs_when_output_disabled(
spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
is_prompt_logprobs: bool = False,
):
# Prompt logprobs are optional
if is_prompt_logprobs and baseline_logprobs is None:
assert spec_logprobs is None
return
assert spec_logprobs is not None
assert baseline_logprobs is not None
assert len(spec_logprobs) == len(baseline_logprobs)
# For each generated position of the sequence.
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
zip(spec_logprobs, baseline_logprobs)):
# First prompt logprob is expected to be None
if is_prompt_logprobs and baseline_pos_logprobs is None:
assert spec_pos_logprobs is None
assert pos == 0
continue
assert spec_pos_logprobs is not None
assert baseline_pos_logprobs is not None
# When disabled, the 1 logprob is returned with dummy values for the
# score and rank, but the token id should match the baseline model
assert len(spec_pos_logprobs) == 1
(spec_pos_logprob_token_id,
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0
assert spec_pos_logprob_token_id in baseline_pos_logprobs
def run_equality_correctness_test(
......@@ -135,7 +170,10 @@ def run_equality_correctness_test(
disable_seed: bool = False,
ignore_eos: bool = True,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None):
expected_acceptance_rate: Optional[float] = None,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
disable_logprobs: bool = False):
org_args = {
**common_llm_kwargs,
......@@ -157,10 +195,12 @@ def run_equality_correctness_test(
sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
ignore_eos=ignore_eos)
ignore_eos=ignore_eos,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs)
with vllm_runner(**org_args) as vllm_model:
org_outputs = vllm_model.generate(prompts, sampling_params)
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
with vllm_runner(**sd_args) as vllm_model:
if ensure_all_accepted or expected_acceptance_rate is not None:
......@@ -169,7 +209,7 @@ def run_equality_correctness_test(
'prometheus']
stat_logger.local_interval = -100
sd_outputs = vllm_model.generate(prompts, sampling_params)
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
if ensure_all_accepted or expected_acceptance_rate is not None:
acceptance_rate = (stat_logger.metrics.
......@@ -185,11 +225,16 @@ def run_equality_correctness_test(
if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2
check_outputs_equal(outputs_0_lst=org_outputs,
outputs_1_lst=sd_outputs,
# Only pass token entries, not the logprobs
check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
outputs_1_lst=[out[0:2] for out in sd_outputs],
name_0="org",
name_1="sd")
# Check logprobs if requested
if logprobs is not None or prompt_logprobs is not None:
check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
def run_equality_correctness_test_tp(model,
common_llm_kwargs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment