Commit 1591c68f authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
...@@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None: ...@@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
fc1 = model.model.decoder.layers[0].fc1 fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.linear_method, Fp8LinearMethod) assert isinstance(fc1.quant_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn assert fc1.weight.dtype == torch.float8_e4m3fn
"""Make sure ignore_eos works.
Run `pytest tests/samplers/test_ignore_eos.py`.
"""
import pytest
from vllm import SamplingParams
MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [1024])
def test_beam_search_single_input(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Check that ``ignore_eos=True`` keeps generating up to ``max_tokens``.

    A single fixed prompt is generated with ``ignore_eos`` enabled, and the
    number of produced tokens must fall short of ``max_tokens`` by at least 0
    and by fewer than 10 tokens.
    """
    # The ``example_prompts`` fixture value is replaced by one fixed prompt.
    example_prompts = "1 + 1 is"

    runner = vllm_runner(model, dtype=dtype)
    params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
    results = runner.model.generate(example_prompts, sampling_params=params)

    # Only the first completion of the first request is inspected.
    num_generated = len(results[0].outputs[0].token_ids)
    print(num_generated)

    shortfall = max_tokens - num_generated
    assert shortfall < 10
    assert shortfall >= 0
...@@ -9,15 +9,26 @@ MODELS = ["facebook/opt-125m"] ...@@ -9,15 +9,26 @@ MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
def test_get_prompt_logprobs( def test_get_prompt_logprobs(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
model, model,
dtype, dtype,
chunked_prefill_token_size: int,
num_top_logprobs: int,
example_prompts, example_prompts,
): ):
max_num_seqs = 256
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
max_num_batched_tokens = chunked_prefill_token_size
max_tokens = 5 max_tokens = 5
num_top_logprobs = 6
hf_model = hf_runner(model, dtype=dtype) hf_model = hf_runner(model, dtype=dtype)
hf_logprobs = hf_model.generate_greedy_logprobs( hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts, example_prompts,
...@@ -25,10 +36,17 @@ def test_get_prompt_logprobs( ...@@ -25,10 +36,17 @@ def test_get_prompt_logprobs(
) )
del hf_model del hf_model
vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) vllm_model = vllm_runner(
model,
dtype=dtype,
max_logprobs=num_top_logprobs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
)
vllm_sampling_params = SamplingParams(max_tokens=max_tokens, vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=num_top_logprobs, logprobs=num_top_logprobs,
prompt_logprobs=5, prompt_logprobs=num_top_logprobs,
temperature=0.0) temperature=0.0)
vllm_results = vllm_model.model.generate( vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params) example_prompts, sampling_params=vllm_sampling_params)
...@@ -52,9 +70,18 @@ def test_get_prompt_logprobs( ...@@ -52,9 +70,18 @@ def test_get_prompt_logprobs(
"The output text from the top logprob for each token position " "The output text from the top logprob for each token position "
"should be the same as the output text in the result.") "should be the same as the output text in the result.")
# The first prompt logprob is always None
assert result.prompt_logprobs[0] is None
for prompt_logprobs in result.prompt_logprobs[1:]:
# If the prompt token is not included in the top X
# logprob, it can return 1 more data
assert (len(prompt_logprobs) == num_top_logprobs
or len(prompt_logprobs) == num_top_logprobs + 1)
# Test whether prompt logprobs are consistent with HF # Test whether prompt logprobs are consistent with HF
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs): for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
# Check prompt logprobs # Check prompt logprobs
# The first prompt logprob is always None, so we compare it from 1:.
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
for token_id, logprob in vllm_prompt_logprob_dict.items(): for token_id, logprob in vllm_prompt_logprob_dict.items():
...@@ -74,6 +101,17 @@ def test_get_prompt_logprobs( ...@@ -74,6 +101,17 @@ def test_get_prompt_logprobs(
"The token should be decoded by the time it is returned " "The token should be decoded by the time it is returned "
" to the user.") " to the user.")
# Test if prompt logprobs are correctly set.
for vllm_result in vllm_results:
token_ids = vllm_result.prompt_token_ids
prompt_logprobs = vllm_result.prompt_logprobs
# The first token doesn't have logprob.
assert prompt_logprobs[0] is None
for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
assert token_id in logprob_dict
def test_max_logprobs(): def test_max_logprobs():
runner = VllmRunner("facebook/opt-125m", max_logprobs=1) runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
from transformers import GenerationConfig, GenerationMixin from transformers import GenerationConfig, GenerationMixin
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import Counter from vllm.utils import Counter
...@@ -54,9 +55,10 @@ def _do_sample( ...@@ -54,9 +55,10 @@ def _do_sample(
sampler: MockLogitsSampler, sampler: MockLogitsSampler,
model_runner: ModelRunner, model_runner: ModelRunner,
sampling_params: SamplingParams, sampling_params: SamplingParams,
device: str,
): ):
seq_group_metadata_list = [] seq_group_metadata_list = []
prompt_lens = [] seq_lens = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -66,11 +68,14 @@ def _do_sample( ...@@ -66,11 +68,14 @@ def _do_sample(
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = SamplingMetadata.prepare(
prompt_lens, seq_group_metadata_list,
subquery_lens=prompt_lens) seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
...@@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str): ...@@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):
sampling_params = SamplingParams(temperature=0) sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params) sampling_params, device)
expected = torch.argmax(fake_logits, dim=-1) expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output): for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples: for nth_output in sequence_output.samples:
...@@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str): ...@@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
n=random.randint(1, 10), n=random.randint(1, 10),
) )
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params) sampling_params, device)
for i, sequence_output in enumerate(sampler_output): for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples: for nth_output in sequence_output.samples:
...@@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str): ...@@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
seed=random.randint(0, 10000), seed=random.randint(0, 10000),
) )
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params) sampling_params, device)
for i, sequence_output in enumerate(sampler_output): for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples: for nth_output in sequence_output.samples:
...@@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): ...@@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
seed=random.randint(0, 10000), seed=random.randint(0, 10000),
) )
first_sampler_output = _do_sample(batch_size, fake_logits, sampler, first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params) model_runner, sampling_params, device)
second_sampler_output = _do_sample(batch_size, fake_logits, sampler, second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params) model_runner, sampling_params, device)
assert first_sampler_output == second_sampler_output assert first_sampler_output == second_sampler_output
...@@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str): ...@@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
best_of=2, best_of=2,
use_beam_search=True, use_beam_search=True,
) )
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params) _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params,
device)
# no assertion here as I am not sure how to determine whether # no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests # the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler # whether there are no exceptions in the sampler
...@@ -201,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -201,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def create_sampling_params(min_tokens, def create_sampling_params(min_tokens,
eos_token_id=0, eos_token_id=0,
*, *,
stop_token_ids: Optional[List[str]] = None, stop_token_ids: Optional[List[int]] = None,
prompt_logprobs: Optional[int] = None): prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams( sampling_params = SamplingParams(
min_tokens=min_tokens, min_tokens=min_tokens,
...@@ -210,7 +216,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -210,7 +216,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
# requesting prompt_logprobs changes the structure of `logits` # requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs=prompt_logprobs, prompt_logprobs=prompt_logprobs,
) )
sampling_params.eos_token_id = eos_token_id sampling_params.all_stop_token_ids.add(eos_token_id)
return sampling_params return sampling_params
def create_sequence_data(num_input=3, num_generated=0): def create_sequence_data(num_input=3, num_generated=0):
...@@ -415,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -415,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
"Invalid test case, need seq_group_metadata_list" "Invalid test case, need seq_group_metadata_list"
batch_size = 0 batch_size = 0
prompt_lens = [] seq_lens = []
sampling_params_per_row = [] sampling_params_per_row = []
for sgm in seq_group_metadata_list: for sgm in seq_group_metadata_list:
sampling_params = sgm.sampling_params sampling_params = sgm.sampling_params
...@@ -425,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -425,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
# a prompt seq_group has only one sequence # a prompt seq_group has only one sequence
seq_data = next(iter(sgm.seq_data.values())) seq_data = next(iter(sgm.seq_data.values()))
prompt_len = seq_data.get_prompt_len() prompt_len = seq_data.get_prompt_len()
prompt_lens.append(prompt_len) seq_lens.append(prompt_len)
if sgm.sampling_params.prompt_logprobs: if sgm.sampling_params.prompt_logprobs:
# with prompt_logprobs each token in the prompt has a row in # with prompt_logprobs each token in the prompt has a row in
...@@ -443,20 +449,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -443,20 +449,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
"batch size") "batch size")
_, fake_logits, sampler, model_runner = _prepare_test(batch_size) _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = model_runner._prepare_sample( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
prompt_lens=prompt_lens if prompt_lens else None, seq_lens=seq_lens if seq_lens else None,
subquery_lens=prompt_lens if prompt_lens else None) query_lens=seq_lens if seq_lens else None,
device=device,
pin_memory=model_runner.pin_memory)
# the logits tensor is modified in-place by the sampler # the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
for logits_idx, (should_penalize, sampling_params) in enumerate( for logits_idx, (should_penalize, sampling_params) in enumerate(
zip(expected_penalization, sampling_params_per_row)): zip(expected_penalization, sampling_params_per_row)):
tokens_to_check = [sampling_params.eos_token_id] tokens_to_check = sampling_params.all_stop_token_ids
if sampling_params.stop_token_ids:
tokens_to_check.extend(sampling_params.stop_token_ids)
tokens_to_check = set(tokens_to_check)
if should_penalize: if should_penalize:
for token_id in tokens_to_check: for token_id in tokens_to_check:
...@@ -492,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -492,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str):
seq_group_metadata_list = [] seq_group_metadata_list = []
expected_tokens: List[Optional[List[int]]] = [] expected_tokens: List[Optional[List[int]]] = []
prompt_lens = [] seq_lens = []
for i in range(batch_size): for i in range(batch_size):
expected: Optional[List[int]] = None expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3) sampling_type = random.randint(0, 3)
...@@ -527,11 +532,15 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -527,11 +532,15 @@ def test_sampler_mixed(seed: int, device: str):
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
def test_sampling(model_runner: ModelRunner): def test_sampling(model_runner: ModelRunner):
sampling_metadata = model_runner._prepare_sample( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)
sampler_output = sampler(logits=fake_logits, sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata) sampling_metadata=sampling_metadata)
...@@ -566,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -566,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str):
# Shuffle the batch and resample # Shuffle the batch and resample
target_index = list(range(batch_size)) target_index = list(range(batch_size))
for list_to_shuffle in (target_index, seq_group_metadata_list, for list_to_shuffle in (target_index, seq_group_metadata_list,
expected_tokens, prompt_lens): expected_tokens, seq_lens):
random.Random(seed).shuffle(list_to_shuffle) random.Random(seed).shuffle(list_to_shuffle)
target_index = torch.tensor(target_index) target_index = torch.tensor(target_index)
input_tensor.data = input_tensor.index_select(0, target_index) input_tensor.data = input_tensor.index_select(0, target_index)
...@@ -611,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -611,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
assert len(warpers) == 2 # top_p and top_k assert len(warpers) == 2 # top_p and top_k
seq_group_metadata_list = [] seq_group_metadata_list = []
prompt_lens = [] seq_lens = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -625,11 +634,14 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -625,11 +634,14 @@ def test_sampler_top_k_top_p(seed: int, device: str):
), ),
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = SamplingMetadata.prepare(
prompt_lens, seq_group_metadata_list,
subquery_lens=prompt_lens) seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)
sample_probs = None sample_probs = None
......
from typing import List, Tuple import asyncio
import time
from itertools import cycle
from typing import Dict, List, Optional, Tuple, Union
import pytest import pytest
import ray
import torch
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
nvmlInit)
from tests.conftest import cleanup from tests.conftest import cleanup
from vllm import LLM from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid
class AsyncLLM:
    """Blocking, test-only wrapper that drives generation via AsyncLLMEngine.

    Note: the regular LLM class in vllm does not support async mode, so for
    test purposes an async-backed variant is implemented here. It may move to
    vllm/entrypoints/llm.py in the future.

    The implementation is borrowed from vllm/entrypoints/llm.py, changed so
    that requests are executed through ``AsyncLLMEngine``. The engine
    arguments are stored at construction time; the engine itself is only
    created inside ``generate``.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        """Record the engine configuration.

        All named parameters and any extra ``kwargs`` are forwarded unchanged
        to ``AsyncEngineArgs``. ``engine_use_ray`` is always set to True.
        Stats logging is disabled unless the caller passes
        ``disable_log_stats`` explicitly.
        """
        kwargs.setdefault("disable_log_stats", True)
        self.engine_args = AsyncEngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            engine_use_ray=True,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.request_counter = Counter()

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        """Generate completions for ``prompts``, one request at a time.

        Args:
            prompts: A single prompt or a list of prompts. Required.
            sampling_params: One ``SamplingParams`` applied to every prompt,
                or a list with exactly one entry per prompt. Defaults to
                ``SamplingParams()``.
            prompt_token_ids: Accepted for signature compatibility; not used
                by this implementation.
            use_tqdm: Accepted for signature compatibility; not used.
            lora_request: Accepted for signature compatibility; not used.
            multi_modal_data: Accepted for signature compatibility; not used.

        Returns:
            The final output streamed for each prompt, in prompt order.

        Raises:
            ValueError: If ``prompts`` is None, or ``sampling_params`` is a
                list whose length differs from the number of prompts.
        """
        # Validate before building the engine so that an invalid call does
        # not leave an engine (and its Ray resources) behind.
        if prompts is None:
            raise ValueError("prompts must be provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        num_requests = len(prompts)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()
        if isinstance(sampling_params, list):
            if len(sampling_params) != num_requests:
                raise ValueError("The lengths of prompts and "
                                 "sampling_params must be the same.")
            # Each prompt gets its own entry rather than the whole list.
            per_request_params = sampling_params
        else:
            per_request_params = [sampling_params] * num_requests

        llm_engine = AsyncLLMEngine.from_engine_args(
            self.engine_args, usage_context=UsageContext.LLM_CLASS)

        async def get_output(prompt,
                             sampling_param) -> Optional[RequestOutput]:
            # Drain the stream; only the last emitted output is kept.
            request_id = random_uuid()
            results_generator = llm_engine.generate(prompt, sampling_param,
                                                    request_id)
            final_output = None
            async for request_output in results_generator:
                final_output = request_output
            return final_output

        outputs = []
        try:
            for prompt, params in zip(prompts, per_request_params):
                outputs.append(asyncio.run(get_output(prompt, params)))
        finally:
            # Always shut Ray down once generation ends, even on failure.
            ray.shutdown()
        return outputs
@pytest.fixture @pytest.fixture
...@@ -35,9 +157,20 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs, ...@@ -35,9 +157,20 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
test_name = request.node.name test_name = request.node.name
def generator_inner(): def generator_inner():
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = LLM(**kwargs)
wait_for_gpu_memory_to_clear(
devices=list(range(torch.cuda.device_count())),
threshold_bytes=2 * 2**30,
timeout_s=60,
)
use_async = False
if "use_async" in kwargs:
use_async = kwargs.pop("use_async")
print(f'{use_async=}')
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
set_random_seed(seed) set_random_seed(seed)
yield llm yield llm
...@@ -64,3 +197,109 @@ def get_output_from_llm_generator( ...@@ -64,3 +197,109 @@ def get_output_from_llm_generator(
del llm del llm
return tokens, token_ids return tokens, token_ids
def get_logprobs_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> List[List[Dict[int, Logprob]]]:
    """Run generation and collect the logprobs of each request's first output.

    For every LLM yielded by ``llm_generator()``, ``generate`` is called with
    ``prompts`` and ``sampling_params``. The ``logprobs`` of ``outputs[0]``
    are copied per request, giving one list of per-position
    ``{token_id: Logprob}`` dicts per sequence in the batch. The result from
    the last yielded LLM is returned.
    """
    for llm in llm_generator():
        request_outputs = llm.generate(prompts,
                                       sampling_params,
                                       use_tqdm=True)
        logprobs = []
        for request_output in request_outputs:
            # Slice-copy so the result does not alias the output object.
            logprobs.append(request_output.outputs[0].logprobs[:])
        # Drop the local reference to the LLM once its outputs are captured.
        del llm

    return logprobs
def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         print_tokens: bool = False):
    """Assert that the test LLM and the baseline LLM agree under greedy
    sampling.

    Both generators are run on the same prompts with temperature zero, and
    the generated token ids must match exactly, sequence by sequence. When
    ``force_output_len`` is true, EOS is ignored so that generation runs to
    ``max_output_len`` tokens. With ``print_tokens`` set, the tokens and
    token ids of both runs are printed for each sequence before comparison.
    """
    seed_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]
    # Repeat the seed prompts round-robin until the batch is full.
    prompts = [
        seed_prompts[idx % len(seed_prompts)] for idx in range(batch_size)
    ]

    # Forcing the output length is done by ignoring the EOS token.
    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=force_output_len,
        temperature=0.0,
    )

    # The test LLM runs first, then the baseline.
    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)
    baseline_batch_tokens, baseline_batch_token_ids = (
        get_output_from_llm_generator(baseline_llm_generator, prompts,
                                      sampling_params))

    expected_count = len(prompts)
    assert len(baseline_batch_token_ids) == expected_count
    assert len(spec_batch_token_ids) == expected_count

    per_sequence = zip(baseline_batch_token_ids, baseline_batch_tokens,
                       spec_batch_token_ids, spec_batch_tokens)
    for i, per_sequence_item in enumerate(per_sequence):
        (baseline_token_ids, baseline_tokens, spec_token_ids,
         spec_tokens) = per_sequence_item
        if print_tokens:
            print(f'{i=} {baseline_tokens=}')
            print(f'{i=} {spec_tokens=}')
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=} {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids
def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    """Poll GPU memory usage until every device is at or below a threshold.

    Usage is sampled through NVML for each index in ``devices`` and printed
    on every poll. The function returns once all devices use at most
    ``threshold_bytes``; otherwise it sleeps five seconds and polls again.

    Raises:
        ValueError: If a poll finds some device still above the threshold
            after at least ``timeout_s`` seconds have elapsed.
    """
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    nvmlInit()
    threshold_gb = threshold_bytes / 2**30
    start_time = time.time()

    while True:
        used_gb_by_device = {}
        for device in devices:
            handle = nvmlDeviceGetHandleByIndex(device)
            used_gb_by_device[device] = (
                nvmlDeviceGetMemoryInfo(handle).used / 2**30)

        # One report line per poll: "<device>=<GB used>; " for each device.
        usage_report = ''.join(f'{index}={used_gb:.02f}; '
                               for index, used_gb in used_gb_by_device.items())
        print('gpu memory used (GB): ' + usage_report)

        dur_s = time.time() - start_time
        still_busy = any(used_gb > threshold_gb
                         for used_gb in used_gb_by_device.values())
        if not still_busy:
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            return

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)
...@@ -42,10 +42,17 @@ def test_spec_decode_xfail_ray(test_llm_generator): ...@@ -42,10 +42,17 @@ def test_spec_decode_xfail_ray(test_llm_generator):
temperature=temperature, temperature=temperature,
) )
with pytest.raises(AssertionError, try:
match="Speculative decoding not yet supported for "): with pytest.raises(
get_output_from_llm_generator(test_llm_generator, prompts, AssertionError,
sampling_params) match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
finally:
# we need to free up ray resource,
# so that latter test could use the gpu we allocated here
import ray
ray.shutdown()
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
import math
from itertools import cycle
import pytest
from vllm import SamplingParams
from .conftest import get_logprobs_from_llm_generator
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "model": "JackFram/llama-68m",

            # Skip cuda graph recording for fast test.
            "enforce_eager": True,

            # Required for spec decode.
            "use_v2_block_manager": True,
            "max_logprobs": 6,
        },
    ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
        },
    ])
@pytest.mark.parametrize("batch_size", [8])
# Use smaller output len for fast test.
@pytest.mark.parametrize("output_len", [7])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Check that speculative decoding leaves the output logprobs unchanged.

    Delegates to ``run_greedy_logprobs_correctness_test`` with the output
    length forced to ``output_len``.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator=baseline_llm_generator,
        test_llm_generator=test_llm_generator,
        batch_size=batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "model": "JackFram/llama-68m",

            # Skip cuda graph recording for fast test.
            "enforce_eager": True,

            # Required for spec decode.
            "use_v2_block_manager": True,
            "max_logprobs": 6,
        },
    ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
        },
    ])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("num_logprobs", [6])
# Use smaller output len for fast test.
@pytest.mark.parametrize("output_len", [7])
@pytest.mark.parametrize("seed", [1])
def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int,
                           num_logprobs: int):
    """Check logprob equality with and without spec decode when more than
    one logprob per position is requested.

    ``num_logprobs`` is forwarded as the ``logprob_rank`` of the shared
    correctness helper.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator=baseline_llm_generator,
        test_llm_generator=test_llm_generator,
        batch_size=batch_size,
        max_output_len=output_len,
        force_output_len=True,
        logprob_rank=num_logprobs,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "model": "JackFram/llama-68m",

            # Skip cuda graph recording for fast test.
            "enforce_eager": True,

            # Required for spec decode.
            "use_v2_block_manager": True
        },
    ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
        },
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 6,
        },
    ])
@pytest.mark.parametrize("batch_size", [8])
# Use smaller output len for fast test.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
                              batch_size: int, output_len: int):
    """Verify logprob greedy equality with different speculation lengths.

    The test LLM is parametrized with 3 and with 6 speculative tokens.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator=baseline_llm_generator,
        test_llm_generator=test_llm_generator,
        batch_size=batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "model": "JackFram/llama-68m",

            # Skip cuda graph recording for fast test.
            "enforce_eager": True,

            # Required for spec decode.
            "use_v2_block_manager": True
        },
    ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,

            # Artificially limit the draft model max model len; this forces
            # vLLM to skip speculation once the sequences grow beyond 32-k
            # tokens.
            "speculative_max_model_len": 32,
        },
    ])
@pytest.mark.parametrize("batch_size", [8])
# Use smaller output len for fast test.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_when_skip_speculation(baseline_llm_generator,
                                        test_llm_generator, batch_size: int,
                                        output_len: int):
    """Check greedy logprob equality when some sequences skip speculation.

    The draft model's max length is capped through
    ``speculative_max_model_len`` in the test LLM kwargs.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator=baseline_llm_generator,
        test_llm_generator=test_llm_generator,
        batch_size=batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "model": "JackFram/llama-68m",

            # Skip cuda graph recording for fast test.
            "enforce_eager": True,

            # Required for spec decode.
            "use_v2_block_manager": True
        },
    ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
        },
    ])
@pytest.mark.parametrize("batch_size", [1])
# Use smaller output len for fast test.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
                         batch_size: int, output_len: int):
    """Check that sampling at temperature 1 yields at least one position
    with more than the requested number of logprobs.

    This covers the case where the sampled token is not among the top-k
    logprobs. Only the speculative (test) LLM is run. Validating equality
    against the non-speculative baseline is left as a future improvement.
    """
    # NOTE: the parametrized ``batch_size`` is overridden here.
    batch_size = 8
    logprob_rank = 5

    seed_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]
    # Repeat the seed prompts round-robin until the batch is full.
    prompts = [
        seed_prompts[idx % len(seed_prompts)] for idx in range(batch_size)
    ]

    # EOS is ignored so that every sequence generates output_len tokens.
    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=1.0,
        logprobs=logprob_rank,
    )

    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    # Number of logprob entries returned at every generated position.
    num_returned_logprobs = []
    for seq_logprobs in spec_batch_logprobs:
        for logprob_dict in seq_logprobs:
            num_returned_logprobs.append(len(logprob_dict))

    # Assert one of the returned logprobs has > num_logprobs (indicating the
    # sampled token is not in top-k).
    assert any(count > logprob_rank for count in num_returned_logprobs)
def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         logprob_rank: int = 1):
    """Compare the logprobs produced by the baseline LLM and the test LLM.

    Both generators are run on the same prompts with temperature zero. For
    every generated position the set of returned ranks must match, each rank
    must map to the same token id, and the logprobs must agree within an
    absolute tolerance of 1e-1.

    Args:
        baseline_llm_generator: Generator for the reference LLM.
        test_llm_generator: Generator for the LLM under test.
        batch_size: Number of prompts to run.
        max_output_len: Maximum number of tokens to generate per prompt.
        force_output_len: When true, the eos token is ignored so generation
            runs to ``max_output_len`` tokens.
        logprob_rank: Number of logprobs requested per position.
    """
    base_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]
    # Repeat the base prompts as often as needed to fill the batch.
    prompts = [
        prompt for _, prompt in zip(range(batch_size), cycle(base_prompts))
    ]

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        # Ignoring eos is what forces the full output length.
        ignore_eos=force_output_len,
        temperature=0.0,
        logprobs=logprob_rank,
    )

    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)
    baseline_batch_logprobs = get_logprobs_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    expected_num_seqs = len(prompts)
    assert len(baseline_batch_logprobs) == expected_num_seqs
    assert len(spec_batch_logprobs) == expected_num_seqs

    # Walk the batch sequence by sequence.
    for baseline_logprobs, spec_logprobs in zip(baseline_batch_logprobs,
                                                spec_batch_logprobs):
        assert len(spec_logprobs) == len(baseline_logprobs)

        # Walk the sequence position by position.
        for spec_pos_logprobs, baseline_pos_logprobs in zip(
                spec_logprobs, baseline_logprobs):
            # rank -> (token id, logprob) for each side.
            spec_by_rank = {
                entry.rank: (token_id, entry.logprob)
                for token_id, entry in spec_pos_logprobs.items()
            }
            baseline_by_rank = {
                entry.rank: (token_id, entry.logprob)
                for token_id, entry in baseline_pos_logprobs.items()
            }

            # Both sides must report exactly the same ranks.
            assert set(spec_by_rank) == set(baseline_by_rank)

            # Token id and logprob must agree rank by rank.
            for rank in sorted(spec_by_rank):
                spec_token_id, spec_logprob = spec_by_rank[rank]
                baseline_token_id, baseline_logprob = baseline_by_rank[rank]

                assert spec_token_id == baseline_token_id, f"{rank}"
                assert math.isclose(
                    spec_logprob,
                    baseline_logprob,
                    abs_tol=1e-1,
                )
...@@ -35,7 +35,8 @@ from transformers import AutoTokenizer ...@@ -35,7 +35,8 @@ from transformers import AutoTokenizer
from vllm import SamplingParams from vllm import SamplingParams
from .conftest import get_output_from_llm_generator from .conftest import (get_output_from_llm_generator,
run_greedy_equality_correctness_test)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -49,7 +50,7 @@ from .conftest import get_output_from_llm_generator ...@@ -49,7 +50,7 @@ from .conftest import get_output_from_llm_generator
"enforce_eager": True, "enforce_eager": True,
# Required for spec decode. # Required for spec decode.
"use_v2_block_manager": True "use_v2_block_manager": True,
}]) }])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"per_test_common_llm_kwargs", "per_test_common_llm_kwargs",
...@@ -109,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, ...@@ -109,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
assert actual_tokens.strip() == expected_tokens.strip() assert actual_tokens.strip() == expected_tokens.strip()
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        dict(
            # Use a small model for a fast test.
            model="JackFram/llama-68m",
            # Eager mode avoids cuda graph recording, keeping the test fast.
            enforce_eager=True,
            # Spec decode requires the v2 block manager.
            use_v2_block_manager=True,
            # Use AsyncLLM engine
            use_async=True,
        )
    ],
)
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
# NOTE(review): sibling tests pass the speculative settings through
# "test_llm_kwargs"; here they are in "per_test_common_llm_kwargs" -- confirm
# against the conftest fixtures that this is intended.
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        dict(
            speculative_model="JackFram/llama-68m",
            num_speculative_tokens=5,
        ),
    ],
)
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_with_async_engine(test_llm_generator,
                                           baseline_llm_generator,
                                           batch_size: int):
    """Check that spec decode works with the async LLM engine.

    Runs the greedy-equality helper with a fixed output length of 32 tokens.
    """
    fixed_output_len = 32
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=fixed_output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
...@@ -538,60 +577,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, ...@@ -538,60 +577,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size, batch_size,
max_output_len=output_len, max_output_len=output_len,
force_output_len=True) force_output_len=True)
def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         print_tokens: bool = False):
    """Assert that the baseline LLM and the test LLM generate the same tokens.

    Both generators are run on the same prompts with temperature zero, and the
    generated token ids must be exactly equal for every prompt ("greedy
    equality").

    Args:
        baseline_llm_generator: Generator for the reference LLM.
        test_llm_generator: Generator for the LLM under test.
        batch_size: Number of prompts to run.
        max_output_len: Maximum number of tokens to generate per prompt.
        force_output_len: When true, the eos token is ignored so generation
            runs to ``max_output_len`` tokens.
        print_tokens: When true, print both outputs for every prompt before
            comparing them.
    """
    base_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]
    # Repeat the base prompts as often as needed to fill the batch.
    prompts = [
        prompt for _, prompt in zip(range(batch_size), cycle(base_prompts))
    ]

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        # Ignoring eos is what forces the full output length.
        ignore_eos=force_output_len,
        temperature=0.0,
    )

    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)
    baseline_batch_tokens, baseline_batch_token_ids = (
        get_output_from_llm_generator(baseline_llm_generator, prompts,
                                      sampling_params))

    expected_num_seqs = len(prompts)
    assert len(baseline_batch_token_ids) == expected_num_seqs
    assert len(spec_batch_token_ids) == expected_num_seqs

    per_prompt_outputs = zip(baseline_batch_token_ids, baseline_batch_tokens,
                             spec_batch_token_ids, spec_batch_tokens)
    # The loop variable names are part of the printed output (f-string "=").
    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
            spec_tokens) in enumerate(per_prompt_outputs):
        if print_tokens:
            print(f'{i=} {baseline_tokens=}')
            print(f'{i=}     {spec_tokens=}')
            print(f'{i=} {baseline_token_ids=}')
            print(f'{i=}     {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
The ngram lookup idea comes from https://github.com/apoorvumang/prompt-lookup-decoding
and has been merged into the transformers code base: https://github.com/huggingface/transformers/pull/27775.
Since no model is needed to generate the proposals, we can make
the test cases much simpler than the draft-model multi-step ones.
However, we still need to verify that the scenarios below pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various ngram sizes / speculative sizes
With those tests, we can at least say that ngram speculation does not break the
correctness of the target model outputs.
"""
import pytest
from .conftest import run_greedy_equality_correctness_test
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        dict(
            # Eager mode avoids cuda graph recording, keeping the test fast.
            enforce_eager=True,
            # Spec decode requires the v2 block manager.
            use_v2_block_manager=True,
            # Print spec metrics.
            disable_log_stats=False,
        )
    ],
)
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        dict(model="JackFram/llama-68m"),
    ],
)
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        dict(
            speculative_model="[ngram]",
            num_speculative_tokens=5,
            ngram_prompt_lookup_max=3,
        ),
    ],
)
@pytest.mark.parametrize("output_len", [256])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
                                      test_llm_generator, batch_size: int,
                                      output_len: int):
    """Check greedy equality on a tiny model for each parametrized batch size."""
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        dict(
            block_size=8,
            # 2 for small prompt, 256//8 for generated.
            num_gpu_blocks_override=2 + 256 // 8,
            max_model_len=(2 + 256 // 8) * 8,
            # Eager mode avoids cuda graph recording, keeping the test fast.
            enforce_eager=True,
            # Spec decode requires the v2 block manager.
            use_v2_block_manager=True,
        )
    ],
)
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        dict(model="JackFram/llama-160m"),
    ],
)
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        dict(
            speculative_model="[ngram]",
            num_speculative_tokens=5,
            ngram_prompt_lookup_max=3,
        ),
    ],
)
# Use small output len for fast test.
@pytest.mark.parametrize("output_len", [256])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                      test_llm_generator,
                                                      batch_size: int,
                                                      output_len: int):
    """Check greedy equality even when some sequences are preempted
    mid-generation.

    The GPU block budget is overridden to a small value in the parametrized
    ``common_llm_kwargs``.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        dict(
            model="JackFram/llama-68m",
            # Eager mode avoids cuda graph recording, keeping the test fast.
            enforce_eager=True,
            # Spec decode requires the v2 block manager.
            use_v2_block_manager=True,
        )
    ],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        dict(
            speculative_model="[ngram]",
            num_speculative_tokens=k,
            ngram_prompt_lookup_max=lookup_max,
        )
        # Every lookup window (3 first, then 1) is combined with each
        # speculation length k.
        for lookup_max in (3, 1)
        for k in (1, 3, 5)
    ],
)
@pytest.mark.parametrize("batch_size", [2])
# A small output length keeps the test fast.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Check that ngram speculative decoding matches non-speculative decoding
    exactly across several values of k and of ngram_prompt_lookup_max.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
...@@ -5,13 +5,12 @@ import pytest ...@@ -5,13 +5,12 @@ import pytest
import torch import torch
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.multi_step_worker import (DraftModelTop1Proposer, from vllm.spec_decode.multi_step_worker import MultiStepWorker
MultiStepWorker) from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
from .utils import (assert_logprobs_dict_allclose, create_batch, from .utils import (assert_logprobs_dict_allclose, create_batch,
create_execute_model_data,
create_seq_group_metadata_from_prompts, create_worker, create_seq_group_metadata_from_prompts, create_worker,
patch_execute_model_with_seeds, zero_kv_cache) patch_execute_model_with_seeds, zero_kv_cache)
...@@ -34,7 +33,7 @@ def test_assert_enough_kv_space(num_steps: int): ...@@ -34,7 +33,7 @@ def test_assert_enough_kv_space(num_steps: int):
list(range(block_size * 2)), list(range(block_size * 2)),
] ]
final_seq_lens = [ final_prompt_lens = [
len(prompt + output) + num_steps len(prompt + output) + num_steps
for prompt, output in zip(prompts, prev_output_tokens) for prompt, output in zip(prompts, prev_output_tokens)
] ]
...@@ -43,7 +42,7 @@ def test_assert_enough_kv_space(num_steps: int): ...@@ -43,7 +42,7 @@ def test_assert_enough_kv_space(num_steps: int):
prompts, prompts,
num_gpu_blocks, num_gpu_blocks,
block_size, block_size,
final_seq_lens, final_prompt_lens,
continuations=prev_output_tokens) continuations=prev_output_tokens)
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
...@@ -103,29 +102,34 @@ def test_same_output_for_single_step(): ...@@ -103,29 +102,34 @@ def test_same_output_for_single_step():
[6, 7, 8, 9, 10], [6, 7, 8, 9, 10],
] ]
final_seq_lens = [len(prompt) + num_steps for prompt in prompts] final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
multi_step_execute_model_data = create_execute_model_data( multi_step_seq_group = create_seq_group_metadata_from_prompts(
seq_group_metadata_list=create_seq_group_metadata_from_prompts( prompts,
prompts, num_gpu_blocks, block_size, num_gpu_blocks,
final_seq_lens=final_seq_lens)) block_size,
final_prompt_lens=final_prompt_lens)
single_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
zero_kv_cache(multi_step_worker.cache_engine) zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed) set_random_seed(seed)
actual_output = multi_step_worker.execute_model_multi_step( actual_output, _ = multi_step_worker.sampler_output(
**multi_step_execute_model_data.to_dict(), num_steps=num_steps) execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=multi_step_seq_group),
sample_len=num_steps)
assert len(actual_output) == num_steps assert len(actual_output) == num_steps
actual_output = actual_output[0] actual_output = actual_output[0]
single_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
zero_kv_cache(worker.cache_engine) zero_kv_cache(worker.cache_engine)
set_random_seed(seed) set_random_seed(seed)
expected_output = worker.execute_model( expected_output = worker.execute_model(
**single_step_execute_model_data.to_dict(), )[0] execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=single_step_seq_group))[0]
actual_token_ids = [ actual_token_ids = [
output.samples[0].output_token for output in actual_output output.samples[0].output_token for output in actual_output
...@@ -181,7 +185,7 @@ def test_same_output_for_multi_step(): ...@@ -181,7 +185,7 @@ def test_same_output_for_multi_step():
random.randint(0, 1000) for _ in range(random.randint(10, 20)) random.randint(0, 1000) for _ in range(random.randint(10, 20))
] for _ in range(10)] ] for _ in range(10)]
final_seq_lens = [len(prompt) + num_steps for prompt in prompts] final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds( multi_step_worker.execute_model = patch_execute_model_with_seeds(
...@@ -189,19 +193,20 @@ def test_same_output_for_multi_step(): ...@@ -189,19 +193,20 @@ def test_same_output_for_multi_step():
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
continuations = [[1] for _ in prompts] continuations = [[1] for _ in prompts]
execute_model_data = create_execute_model_data( seq_group_metadata_list = create_seq_group_metadata_from_prompts(
create_seq_group_metadata_from_prompts( prompts,
prompts, num_gpu_blocks,
num_gpu_blocks, block_size,
block_size, continuations=continuations,
continuations=continuations, final_prompt_lens=final_prompt_lens)
final_seq_lens=final_seq_lens), )
# Run multi-step. # Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine) zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed) set_random_seed(seed)
multi_step_output = multi_step_worker.execute_model_multi_step( multi_step_output, _ = multi_step_worker.sampler_output(
**execute_model_data.to_dict(), num_steps=num_steps) execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps)
# Run single-step repeatedly. # Run single-step repeatedly.
zero_kv_cache(worker.cache_engine) zero_kv_cache(worker.cache_engine)
...@@ -211,16 +216,16 @@ def test_same_output_for_multi_step(): ...@@ -211,16 +216,16 @@ def test_same_output_for_multi_step():
for _ in multi_step_output: for _ in multi_step_output:
execute_model_data = create_execute_model_data( seq_group_metadata_list = create_seq_group_metadata_from_prompts(
create_seq_group_metadata_from_prompts( prompts,
prompts, num_gpu_blocks,
num_gpu_blocks, block_size,
block_size, continuations=continuations,
continuations=continuations, final_prompt_lens=final_prompt_lens)
final_seq_lens=final_seq_lens))
single_step_output.extend( single_step_output.extend(
worker.execute_model(**execute_model_data.to_dict(), )) worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data. # Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]): for i, seq_group_output in enumerate(single_step_output[-1]):
...@@ -266,7 +271,7 @@ def test_same_output_for_multi_step(): ...@@ -266,7 +271,7 @@ def test_same_output_for_multi_step():
@torch.inference_mode() @torch.inference_mode()
def test_draft_proposals_full_speculation_len(): def test_draft_proposals_full_speculation_len():
"""Verify DraftModelTop1Proposer correctly handles case where all sequences """Verify Top1Proposer correctly handles case where all sequences
can speculate. can speculate.
""" """
k = 10 k = 10
...@@ -275,33 +280,36 @@ def test_draft_proposals_full_speculation_len(): ...@@ -275,33 +280,36 @@ def test_draft_proposals_full_speculation_len():
device = 'cuda:0' device = 'cuda:0'
draft_worker = MagicMock() draft_worker = MagicMock()
proposer = DraftModelTop1Proposer( proposer = Top1Proposer(
draft_worker=draft_worker, worker=draft_worker,
device=device, device=device,
max_model_len=2048,
vocab_size=vocab_size, vocab_size=vocab_size,
max_proposal_len=2048,
) )
draft_worker.execute_model_multi_step.return_value = [ draft_worker.sampler_output.return_value = [
SamplerOutput( SamplerOutput(
outputs=[], outputs=[],
sampled_token_probs=torch.rand(batch_size, sampled_token_probs=torch.rand(batch_size,
vocab_size, vocab_size,
device=device, device=device,
dtype=torch.float32), dtype=torch.float32),
logprobs=torch.rand(batch_size,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(low=0, sampled_token_ids=torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, ), size=(batch_size, ),
device=device, device=device,
dtype=torch.long), dtype=torch.long),
) for _ in range(k) ) for _ in range(k)
] ], True
execute_model_data, _, _ = create_batch(batch_size, k) seq_group_metadata_list, _, _ = create_batch(batch_size, k)
proposals = proposer.get_proposals( proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
**execute_model_data.to_dict(), seq_group_metadata_list=seq_group_metadata_list,
max_proposal_len=k, num_lookahead_slots=k), )
)
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
...@@ -315,7 +323,7 @@ def test_draft_proposals_full_speculation_len(): ...@@ -315,7 +323,7 @@ def test_draft_proposals_full_speculation_len():
@torch.inference_mode() @torch.inference_mode()
def test_draft_proposals_no_speculations(): def test_draft_proposals_no_speculations():
"""Verify DraftModelTop1Proposer correctly handles case where no sequences """Verify Top1Proposer correctly handles case where no sequences
can speculate. can speculate.
""" """
k = 10 k = 10
...@@ -325,21 +333,20 @@ def test_draft_proposals_no_speculations(): ...@@ -325,21 +333,20 @@ def test_draft_proposals_no_speculations():
prompt_len = 10 prompt_len = 10
draft_worker = MagicMock() draft_worker = MagicMock()
proposer = DraftModelTop1Proposer( proposer = Top1Proposer(
draft_worker=draft_worker, worker=draft_worker,
device=device, device=device,
max_model_len=prompt_len + k - 1,
vocab_size=vocab_size, vocab_size=vocab_size,
max_proposal_len=prompt_len + k - 1,
) )
execute_model_data, _, _ = create_batch(batch_size, seq_group_metadata_list, _, _ = create_batch(batch_size,
k, k,
prompt_len=prompt_len) prompt_len=prompt_len)
proposals = proposer.get_proposals( proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
**execute_model_data.to_dict(), seq_group_metadata_list=seq_group_metadata_list,
max_proposal_len=k, num_lookahead_slots=k), )
)
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
...@@ -353,7 +360,7 @@ def test_draft_proposals_no_speculations(): ...@@ -353,7 +360,7 @@ def test_draft_proposals_no_speculations():
@torch.inference_mode() @torch.inference_mode()
def test_draft_proposals_mixed_k(): def test_draft_proposals_mixed_k():
"""Verify DraftModelTop1Proposer correctly handles case some sequences can """Verify Top1Proposer correctly handles case some sequences can
speculate and some can't. speculate and some can't.
""" """
k = 10 k = 10
...@@ -374,20 +381,24 @@ def test_draft_proposals_mixed_k(): ...@@ -374,20 +381,24 @@ def test_draft_proposals_mixed_k():
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len] for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
draft_worker = MagicMock() draft_worker = MagicMock()
proposer = DraftModelTop1Proposer( proposer = Top1Proposer(
draft_worker=draft_worker, worker=draft_worker,
device=device, device=device,
max_model_len=long_prompt_len + prev_output_token_len + k - 1,
vocab_size=vocab_size, vocab_size=vocab_size,
max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
) )
draft_worker.execute_model_multi_step.return_value = [ draft_worker.sampler_output.return_value = [
SamplerOutput( SamplerOutput(
outputs=[], outputs=[],
sampled_token_probs=torch.rand(expected_num_proposal_seqs, sampled_token_probs=torch.rand(expected_num_proposal_seqs,
vocab_size, vocab_size,
device=device, device=device,
dtype=torch.float32), dtype=torch.float32),
logprobs=torch.rand(expected_num_proposal_seqs,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint( sampled_token_ids=torch.randint(
low=0, low=0,
high=vocab_size, high=vocab_size,
...@@ -395,19 +406,18 @@ def test_draft_proposals_mixed_k(): ...@@ -395,19 +406,18 @@ def test_draft_proposals_mixed_k():
device=device, device=device,
dtype=torch.long), dtype=torch.long),
) for _ in range(k) ) for _ in range(k)
] ], True
execute_model_data, _, _ = create_batch( seq_group_metadata_list, _, _ = create_batch(
batch_size, batch_size,
k, k,
prompt_len=prompt_len, prompt_len=prompt_len,
prev_output_token_len=prev_output_token_len, prev_output_token_len=prev_output_token_len,
) )
proposals = proposer.get_proposals( proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
**execute_model_data.to_dict(), seq_group_metadata_list=seq_group_metadata_list,
max_proposal_len=k, num_lookahead_slots=k), )
)
assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs) assert torch.is_tensor(proposals.proposal_probs)
......
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from .utils import create_seq_group_metadata_from_prompts, create_worker
def test_ngram_algo_correctness_for_single_no_match():
    """Check the ngram proposer on a single prompt with no repeated ngram.

    The returned tensors must still have the full proposal shape, while the
    reported proposal length for the prompt must be zero.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    proposal_len = 5

    # No candidate is expected for this prompt.
    prompts = [
        [1, 2, 3, 4, 5, 6, 7],
    ]
    num_prompts = len(prompts)

    ngram_worker = create_worker(
        NGramWorker,
        'JackFram/llama-68m',  # model name
        block_size,
        num_gpu_blocks,
        100,  # seed
    )
    proposer = Top1Proposer(
        worker=ngram_worker,
        device='cuda:0',
        vocab_size=32_000,
        max_proposal_len=20,
    )

    # set ngram window (0, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(0, 3)

    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=[len(prompt) + proposal_len for prompt in prompts])

    request = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=proposal_len)
    proposals = proposer.get_proposals(execute_model_req=request)

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    expected_shape = torch.Size([num_prompts, proposal_len])
    assert proposals.proposal_token_ids.shape == expected_shape
    assert proposals.proposal_probs.shape[:-1] == expected_shape

    assert proposals.proposal_lens.shape == torch.Size([num_prompts])
    assert proposals.proposal_lens.tolist() == [0]
def test_ngram_algo_correctness_for_batches_not_match_all():
    """Check the ngram proposer on a batch where only some prompts match.

    The batch mixes prompts with a repeated ngram, a prompt without one, and a
    prompt longer than the proposer's ``max_proposal_len``. The proposed token
    ids are compared row by row against slices of the prompts.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    proposal_len = 5

    prompts = [
        # shall find no candidate
        [1, 2, 3, 4, 5, 6, 7],
        # shall find candidate 12,13,14,15,16
        [11, 12, 13, 14, 15, 16, 11],
        # shall find candidate 23,24,25,26,21
        [21, 21, 22, 23, 24, 25, 26, 21, 22],
        # shall find candidate 34,35,36,37,38
        [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
        # shall find no candidate as exceed max_proposal_len
        [
            31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
            38, 31, 32, 33
        ],
    ]
    num_prompts = len(prompts)

    ngram_worker = create_worker(
        NGramWorker,
        'JackFram/llama-68m',  # model name
        block_size,
        num_gpu_blocks,
        100,  # seed
    )
    proposer = Top1Proposer(
        worker=ngram_worker,
        device='cuda:0',
        vocab_size=32_000,
        max_proposal_len=20,
    )

    # set ngram window (0, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(0, 3)

    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=[len(prompt) + proposal_len for prompt in prompts])

    request = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=proposal_len)
    proposals = proposer.get_proposals(execute_model_req=request)

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    expected_shape = torch.Size([num_prompts, proposal_len])
    assert proposals.proposal_token_ids.shape == expected_shape
    assert proposals.proposal_probs.shape[:-1] == expected_shape

    assert proposals.proposal_lens.shape == torch.Size([num_prompts])
    # Every prompt except the over-long last one reports a full-length
    # proposal.
    assert proposals.proposal_lens.tolist() == [proposal_len] * 4 + [0]

    expected_rows = [
        # The prompt without a match is expected to hold zeros.
        [0] * proposal_len,
        prompts[1][1:1 + proposal_len],
        prompts[2][3:3 + proposal_len],
        prompts[3][5:5 + proposal_len],
        # The over-long prompt is expected to hold -1.
        [-1] * proposal_len,
    ]
    for row, expected_token_ids in enumerate(expected_rows):
        assert proposals.proposal_token_ids[row].tolist(
        ) == expected_token_ids
def test_ngram_algo_correctness_for_batches_match_all():
    """Check the ngram proposer on a batch where every prompt has a match.

    Every prompt must report a full-length proposal, and the proposed token
    ids are compared row by row against slices of the prompts.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    proposal_len = 5

    prompts = [
        # shall find candidate 12,13,14,15,16
        [11, 12, 13, 14, 15, 16, 11],
        # shall find candidate 23,24,25,26,21
        [21, 21, 22, 23, 24, 25, 26, 21, 22],
        # shall find candidate 34,35,36,37,38
        [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
    ]
    num_prompts = len(prompts)

    ngram_worker = create_worker(
        NGramWorker,
        'JackFram/llama-68m',  # model name
        block_size,
        num_gpu_blocks,
        100,  # seed
    )
    proposer = Top1Proposer(
        worker=ngram_worker,
        device='cuda:0',
        vocab_size=32_000,
        max_proposal_len=20,
    )

    # set ngram window (0, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(0, 3)

    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=[len(prompt) + proposal_len for prompt in prompts])

    request = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=proposal_len)
    proposals = proposer.get_proposals(execute_model_req=request)

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    expected_shape = torch.Size([num_prompts, proposal_len])
    assert proposals.proposal_token_ids.shape == expected_shape
    assert proposals.proposal_probs.shape[:-1] == expected_shape

    assert proposals.proposal_lens.shape == torch.Size([num_prompts])
    assert proposals.proposal_lens.tolist() == [proposal_len] * 3

    # Offset into each prompt where the expected continuation starts.
    continuation_starts = [1, 3, 5]
    for row, start in enumerate(continuation_starts):
        assert proposals.proposal_token_ids[row].tolist(
        ) == prompts[row][start:start + proposal_len]
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector, from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics) SpecDecodeWorkerMetrics)
...@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker ...@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
split_num_cache_blocks_evenly) split_num_cache_blocks_evenly)
from .utils import (ExecuteModelData, create_batch, create_sampler_output_list, from .utils import create_batch, create_sampler_output_list, mock_worker
mock_worker)
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
...@@ -33,27 +32,22 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): ...@@ -33,27 +32,22 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector) metrics_collector)
exception_secret = 'artifical stop' exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
execute_model_data, _, _ = create_batch(batch_size, k) seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), worker.execute_model(execute_model_req=execute_model_req)
num_lookahead_slots=k)
call_args_list = draft_worker.get_spec_proposals.call_args_list call_args_list = draft_worker.get_spec_proposals.call_args_list
assert len(call_args_list) == 1 assert len(call_args_list) == 1
for args, _ in call_args_list: for args, _ in call_args_list:
(seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, actual_execute_model_data = args[0]
blocks_to_copy, actual_k) = args assert actual_execute_model_data == execute_model_req
actual_execute_model_data = ExecuteModelData(seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy)
assert actual_execute_model_data == execute_model_data
assert actual_k == k
@pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('k', [1, 2, 6])
...@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
proposal_lens = torch.ones(batch_size, dtype=torch.int64, proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k device='cuda') * k
execute_model_data, prompts, prev_output_tokens = create_batch( seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
batch_size, k) batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals( draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
...@@ -101,24 +95,24 @@ def test_correctly_calls_target_model(k: int, batch_size: int): ...@@ -101,24 +95,24 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
proposal_probs=proposal_probs, proposal_probs=proposal_probs,
proposal_lens=proposal_lens) proposal_lens=proposal_lens)
exception_secret = 'artifical stop' exception_secret = 'artificial stop'
target_worker.execute_model.side_effect = ValueError(exception_secret) target_worker.execute_model.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), worker.execute_model(execute_model_req=ExecuteModelRequest(
num_lookahead_slots=k) seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
seen_contexts = [] seen_contexts = []
call_args_list = target_worker.execute_model.call_args_list call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1 assert len(call_args_list) == 1
for args, kwargs in call_args_list: for _, kwargs in call_args_list:
target_execute_model_data = ExecuteModelData.from_dict(kwargs) seq_group_metadata_list = kwargs[
"execute_model_req"].seq_group_metadata_list
assert len(target_execute_model_data.seq_group_metadata_list) == ( assert len(seq_group_metadata_list) == (k + 1) * batch_size
k + 1) * batch_size for seq_group_metadata in seq_group_metadata_list:
for seq_group_metadata in (
target_execute_model_data.seq_group_metadata_list):
for seq_data in seq_group_metadata.seq_data.values(): for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids()) seen_contexts.append(seq_data.get_token_ids())
...@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
proposal_lens = torch.ones(batch_size, dtype=torch.int64, proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k device='cuda') * k
execute_model_data, _, _ = create_batch(batch_size, k) seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals( draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids, proposal_token_ids=proposal_token_ids,
...@@ -192,17 +186,24 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): ...@@ -192,17 +186,24 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
vocab_size, vocab_size,
dtype=torch.float32, dtype=torch.float32,
device='cuda') device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids, target_output = create_sampler_output_list(target_token_ids,
target_token_probs) target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]] target_worker.execute_model.return_value = [target_output[0]]
exception_secret = 'artifical stop' exception_secret = 'artificial stop'
rejection_sampler.side_effect = ValueError(exception_secret) rejection_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret): with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(), worker.execute_model(execute_model_req=ExecuteModelRequest(
num_lookahead_slots=k) seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
assert len(rejection_sampler.call_args_list) == 1 assert len(rejection_sampler.call_args_list) == 1
_, kwargs = rejection_sampler.call_args_list[0] _, kwargs = rejection_sampler.call_args_list[0]
...@@ -256,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -256,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
proposal_lens = torch.ones(batch_size, dtype=torch.int64, proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k device='cuda') * k
execute_model_data, _, _ = create_batch(batch_size, k) seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals( draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids, proposal_token_ids=proposal_token_ids,
...@@ -273,8 +274,14 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -273,8 +274,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
vocab_size, vocab_size,
dtype=torch.float32, dtype=torch.float32,
device='cuda') device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids, target_output = create_sampler_output_list(target_token_ids,
target_token_probs) target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]] target_worker.execute_model.return_value = [target_output[0]]
...@@ -290,15 +297,18 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -290,15 +297,18 @@ def test_correctly_formats_output(k: int, batch_size: int):
rejection_sampler.return_value = rejection_sampler_output rejection_sampler.return_value = rejection_sampler_output
output = worker.execute_model(**execute_model_data.to_dict(), output = worker.execute_model(execute_model_req=ExecuteModelRequest(
num_lookahead_slots=k) seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
expected_output = create_sampler_output_list( expected_output = create_sampler_output_list(
rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) token_ids=rejection_sampler_output.transpose(0, 1),
probs=[None for _ in range(k + 1)],
logprobs=[None for _ in range(k + 1)])
seq_ids = [ seq_ids = [
next(iter(seq_group_metadata.seq_data.keys())) next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in execute_model_data.seq_group_metadata_list for seq_group_metadata in seq_group_metadata_list
] ]
actual_output_by_seq = {seq_id: [] for seq_id in seq_ids} actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
expected_output_by_seq = {seq_id: [] for seq_id in seq_ids} expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
...@@ -328,7 +338,6 @@ def test_correctly_formats_output(k: int, batch_size: int): ...@@ -328,7 +338,6 @@ def test_correctly_formats_output(k: int, batch_size: int):
continue continue
assert actual_by_step[i].output_token == expected_by_step[ assert actual_by_step[i].output_token == expected_by_step[
i].output_token i].output_token
assert actual_by_step[i].logprobs == expected_by_step[i].logprobs
@pytest.mark.parametrize('k', [1, 2]) @pytest.mark.parametrize('k', [1, 2])
...@@ -370,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -370,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
proposal_lens = torch.ones(batch_size, dtype=torch.int64, proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k device='cuda') * k
execute_model_data, _, _ = create_batch(batch_size, k) seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals( draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids, proposal_token_ids=proposal_token_ids,
...@@ -387,8 +396,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -387,8 +396,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
vocab_size, vocab_size,
dtype=torch.float32, dtype=torch.float32,
device='cuda') device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids, target_output = create_sampler_output_list(target_token_ids,
target_token_probs) target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]] target_worker.execute_model.return_value = [target_output[0]]
...@@ -409,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): ...@@ -409,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
metrics_collector.maybe_collect_rejsample_metrics.return_value = ( metrics_collector.maybe_collect_rejsample_metrics.return_value = (
mock_rejsample_metrics) mock_rejsample_metrics)
output = worker.execute_model(**execute_model_data.to_dict(), output = worker.execute_model(execute_model_req=ExecuteModelRequest(
num_lookahead_slots=k) seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
call_args_list = ( call_args_list = (
...@@ -443,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int): ...@@ -443,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector) metrics_collector)
execute_model_data, prompts, prev_output_tokens = create_batch( seq_group_metadata_list, _, _ = create_batch(batch_size,
batch_size, k, prev_output_token_len=0) k,
prev_output_token_len=0)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
out = worker.execute_model(**execute_model_data.to_dict(), out = worker.execute_model(execute_model_req=execute_model_req)
num_lookahead_slots=k)
assert len(out) == 1, f"expected only one token output when {k=}" assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None" assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[ assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None" 0].sampled_tokens is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with( draft_worker.execute_model.assert_called_once_with(execute_model_req)
**execute_model_data.to_dict()) target_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())
@pytest.mark.parametrize('k', [0, 5]) @pytest.mark.parametrize('k', [0, 5])
...@@ -484,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int): ...@@ -484,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector) metrics_collector)
execute_model_data, prompts, prev_output_tokens = create_batch( seq_group_metadata_list, _, _ = create_batch(batch_size,
batch_size, k, prev_output_token_len=0) k,
prev_output_token_len=0)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
out = worker.execute_model(**execute_model_data.to_dict(), out = worker.execute_model(execute_model_req=execute_model_req)
num_lookahead_slots=k)
assert len(out) == 1, f"expected only one token output when {k=}" assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None" assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[ assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None" 0].sampled_tokens is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with( draft_worker.execute_model.assert_called_once_with(execute_model_req)
**execute_model_data.to_dict()) target_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
......
from dataclasses import dataclass, fields
from itertools import count from itertools import count
from typing import Dict, Iterable, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union
from unittest.mock import MagicMock from unittest.mock import MagicMock
...@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine ...@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
@dataclass
class ExecuteModelData:
"""Helper data structure which facilitates cleaner tests.
"""
seq_group_metadata_list: List[SequenceGroupMetadata]
blocks_to_swap_in: Dict[int, int]
blocks_to_swap_out: Dict[int, int]
blocks_to_copy: Dict[int, List[int]]
def to_dict(self):
return dict(
(field.name, getattr(self, field.name)) for field in fields(self))
@classmethod
def from_dict(cls, d):
cleaned = dict((field.name, d[field.name]) for field in fields(cls))
return cls(**cleaned)
def round_up_to_next_block(seq_len: int, block_size: int) -> int: def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size return (seq_len + block_size - 1) // block_size
def create_execute_model_data(
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, int]] = None,
) -> ExecuteModelData:
if blocks_to_swap_in is None:
blocks_to_swap_in = {}
if blocks_to_swap_out is None:
blocks_to_swap_out = {}
if blocks_to_copy is None:
blocks_to_copy = {}
return ExecuteModelData(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
def mock_worker(cls=None, def mock_worker(cls=None,
vocab_size: int = 30_000, vocab_size: int = 30_000,
max_model_len: int = 2048, max_model_len: int = 2048,
...@@ -144,7 +103,7 @@ def create_seq_group_metadata_from_prompts( ...@@ -144,7 +103,7 @@ def create_seq_group_metadata_from_prompts(
prompts: List[List[int]], prompts: List[List[int]],
num_gpu_blocks: int, num_gpu_blocks: int,
block_size: int, block_size: int,
final_seq_lens: List[int], final_prompt_lens: List[int],
continuations: Optional[List[List[int]]] = None, continuations: Optional[List[List[int]]] = None,
seq_ids: Optional[List[int]] = None, seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]: ) -> List[SequenceGroupMetadata]:
...@@ -162,7 +121,7 @@ def create_seq_group_metadata_from_prompts( ...@@ -162,7 +121,7 @@ def create_seq_group_metadata_from_prompts(
free_gpu_blocks.pop() free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(final_len, block_size)) for _ in range(round_up_to_next_block(final_len, block_size))
] ]
for i, final_len in enumerate(final_seq_lens) for i, final_len in enumerate(final_prompt_lens)
} }
return [ return [
...@@ -201,6 +160,7 @@ def assert_logprobs_dict_allclose( ...@@ -201,6 +160,7 @@ def assert_logprobs_dict_allclose(
def create_sampler_output_list( def create_sampler_output_list(
token_ids: torch.Tensor, token_ids: torch.Tensor,
probs: Iterable[Optional[torch.Tensor]], probs: Iterable[Optional[torch.Tensor]],
logprobs: Iterable[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
num_steps, batch_size = token_ids.shape num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist() token_ids_by_step = token_ids.tolist()
...@@ -222,6 +182,7 @@ def create_sampler_output_list( ...@@ -222,6 +182,7 @@ def create_sampler_output_list(
) for seq_index, token_id in enumerate(token_ids_by_step[step]) ) for seq_index, token_id in enumerate(token_ids_by_step[step])
], ],
sampled_token_probs=probs[step], sampled_token_probs=probs[step],
logprobs=logprobs[step],
sampled_token_ids=token_ids[step]) sampled_token_ids=token_ids[step])
for step in range(num_steps) for step in range(num_steps)
] ]
...@@ -251,13 +212,12 @@ def create_batch(batch_size, ...@@ -251,13 +212,12 @@ def create_batch(batch_size,
prev_output_tokens = [[ prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len) next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)] ] for _ in range(batch_size)]
final_seq_lens = [ final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1 len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens) for prompt, prev_output_token in zip(prompts, prev_output_tokens)
] ]
execute_model_data = create_execute_model_data( seq_group_metadata_list = create_seq_group_metadata_from_prompts(
create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks, prompts, num_gpu_blocks, block_size, final_prompt_lens,
block_size, final_seq_lens, prev_output_tokens, seq_ids)
prev_output_tokens, seq_ids), ) return seq_group_metadata_list, prompts, prev_output_tokens
return execute_model_data, prompts, prev_output_tokens
...@@ -6,14 +6,14 @@ import uuid ...@@ -6,14 +6,14 @@ import uuid
from functools import partial from functools import partial
from typing import Type from typing import Type
import torch
import torch.nn as nn import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
TensorSerializer, stream_io) TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig from transformers import AutoConfig, PretrainedConfig
from vllm.distributed import initialize_model_parallel from vllm.distributed import (init_distributed_environment,
initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
...@@ -226,7 +226,7 @@ model_name = model_ref.split("/")[1] ...@@ -226,7 +226,7 @@ model_name = model_ref.split("/")[1]
os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080" os.environ["MASTER_PORT"] = "8080"
torch.distributed.init_process_group(world_size=1, rank=0) init_distributed_environment(world_size=1, rank=0, local_rank=0)
initialize_model_parallel() initialize_model_parallel()
keyfile = args.keyfile if args.keyfile else None keyfile = args.keyfile if args.keyfile else None
......
...@@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config): ...@@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_agent_instance.deserialize.return_value = MagicMock() mock_agent_instance.deserialize.return_value = MagicMock()
result = load_with_tensorizer(tensorizer_config, result = load_with_tensorizer(tensorizer_config,
linear_method=mock_linear_method) quant_method=mock_linear_method)
mock_agent.assert_called_once_with(tensorizer_config, mock_agent.assert_called_once_with(tensorizer_config,
linear_method=mock_linear_method) quant_method=mock_linear_method)
mock_agent_instance.deserialize.assert_called_once() mock_agent_instance.deserialize.assert_called_once()
assert result == mock_agent_instance.deserialize.return_value assert result == mock_agent_instance.deserialize.return_value
......
import json
import logging
import os import os
import sys import sys
import tempfile import tempfile
from json.decoder import JSONDecodeError
from tempfile import NamedTemporaryFile
from typing import Any
from unittest.mock import patch
from uuid import uuid4
from vllm.logger import enable_trace_function_call import pytest
from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger,
enable_trace_function_call, init_logger)
from vllm.logging import NewLineFormatter
def f1(x): def f1(x):
...@@ -25,3 +36,179 @@ def test_trace_function_call(): ...@@ -25,3 +36,179 @@ def test_trace_function_call():
assert "f2" in content assert "f2" in content
sys.settrace(None) sys.settrace(None)
os.remove(path) os.remove(path)
def test_default_vllm_root_logger_configuration():
    """Check the "vllm" logger, its first handler and that handler's formatter.

    The expectations below only hold when the default logging setup is in
    effect, i.e. neither VLLM_CONFIGURE_LOGGING nor VLLM_LOGGING_CONFIG_PATH
    has been overridden for the test run.
    """
    vllm_logger = logging.getLogger("vllm")
    assert vllm_logger.level == logging.DEBUG
    assert not vllm_logger.propagate

    # Only the first attached handler is inspected.
    first_handler = vllm_logger.handlers[0]
    assert first_handler.stream == sys.stdout
    assert first_handler.level == logging.INFO

    fmt = first_handler.formatter
    assert fmt is not None
    assert isinstance(fmt, NewLineFormatter)
    assert fmt._fmt == _FORMAT
    assert fmt.datefmt == _DATE_FORMAT
@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
@patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", None)
def test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger():
    """A logger created under the "vllm" namespace relies on the root logger.

    With VLLM_CONFIGURE_LOGGING patched to 1 and VLLM_LOGGING_CONFIG_PATH
    patched to None, a logger obtained from ``init_logger`` is expected to
    have no level and no handlers of its own, and a record it logs must reach
    the first handler of the "vllm" logger.
    """
    root_logger = logging.getLogger("vllm")
    root_handler = root_logger.handlers[0]

    # A random suffix guarantees the logger has not been touched before.
    unique_name = f"vllm.{uuid4()}"
    logger = init_logger(unique_name)
    assert logger.name == unique_name
    assert logger.level == logging.NOTSET
    assert not logger.handlers
    assert logger.propagate

    message = "Hello, world!"
    # Replace the root handler's emit so the record can be inspected instead
    # of being written out.
    with patch.object(root_handler, "emit") as root_handle_mock:
        logger.info(message)

    root_handle_mock.assert_called_once()
    log_record = root_handle_mock.call_args[0][0]
    assert log_record.name == unique_name
    assert log_record.msg == message
    # The original repeated the `msg` check twice; the second check now
    # covers the rendered message instead.
    assert log_record.getMessage() == message
    assert log_record.levelno == logging.INFO
@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0)
@patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", None)
def test_logger_configuring_can_be_disabled():
    """Re-run the root logger setup with configuration switched off.

    VLLM_CONFIGURE_LOGGING is patched to 0 and no config path is given, so
    ``dictConfig`` is expected not to be invoked at all. ``dictConfig`` is
    mocked so the re-run cannot alter the live logging configuration.
    """
    # NOTE(review): this only observes calls made through the
    # ``logging.config`` module attribute. If vllm.logger binds ``dictConfig``
    # by name at import time, the patch target should be that name - confirm.
    with patch("logging.config.dictConfig") as mocked_dict_config:
        _configure_vllm_root_logger()

    mocked_dict_config.assert_not_called()
@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
@patch(
    "vllm.logger.VLLM_LOGGING_CONFIG_PATH",
    "/if/there/is/a/file/here/then/you/did/this/to/yourself.json",
)
def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist():
    """Re-run the root logger setup with a config path that should not exist.

    VLLM_LOGGING_CONFIG_PATH is patched to a path that is assumed to be
    absent from the filesystem; the setup is expected to raise a RuntimeError
    whose message contains "File does not exist".
    """
    with pytest.raises(RuntimeError) as ex_info:
        _configure_vllm_root_logger()

    assert ex_info.type is RuntimeError
    # Match against the exception message itself. str(ex_info) is the
    # ExceptionInfo repr, not the message.
    assert "File does not exist" in str(ex_info.value)
@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
def test_an_error_is_raised_when_custom_logging_config_is_invalid_json():
    """Re-run the root logger setup with a config file that is not JSON.

    A temporary file holding YAML-looking text is used as
    VLLM_LOGGING_CONFIG_PATH; the setup is expected to raise a
    JSONDecodeError whose message contains "Expecting value".
    """
    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
        logging_config_file.write("---\nloggers: []\nversion: 1")
        # Flush so the content is on disk when the file is re-opened by name.
        logging_config_file.flush()
        with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH",
                   logging_config_file.name):
            with pytest.raises(JSONDecodeError) as ex_info:
                _configure_vllm_root_logger()

            assert ex_info.type is JSONDecodeError
            # Match against the exception message itself. str(ex_info) is
            # the ExceptionInfo repr, not the message.
            assert "Expecting value" in str(ex_info.value)
@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
@pytest.mark.parametrize("unexpected_config", (
    "Invalid string",
    [{
        "version": 1,
        "loggers": []
    }],
    0,
))
def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
        unexpected_config: Any):
    """Re-run the root logger setup with valid JSON that is not a dict.

    Each parametrized value (a string, a list and an int) is serialized to a
    temporary file used as VLLM_LOGGING_CONFIG_PATH; the setup is expected to
    raise a ValueError whose message contains
    "Invalid logging config. Expected Dict, got".
    """
    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
        logging_config_file.write(json.dumps(unexpected_config))
        # Flush so the content is on disk when the file is re-opened by name.
        logging_config_file.flush()
        with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH",
                   logging_config_file.name):
            with pytest.raises(ValueError) as ex_info:
                _configure_vllm_root_logger()

            assert ex_info.type is ValueError
            # Match against the exception message itself. str(ex_info) is
            # the ExceptionInfo repr, not the message.
            assert ("Invalid logging config. Expected Dict, got"
                    in str(ex_info.value))
@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1)
def test_custom_logging_config_is_parsed_and_used_when_provided():
    """Re-run the root logger setup with a valid custom config file.

    A dict config is serialized to a temporary file used as
    VLLM_LOGGING_CONFIG_PATH; the setup is expected to pass the parsed dict
    to ``dictConfig``. ``dictConfig`` is mocked so the re-run cannot alter
    the live logging configuration.
    """
    valid_logging_config = {
        "loggers": {
            "vllm.test_logger.logger": {
                "handlers": [],
                "propagate": False,
            }
        },
        "version": 1
    }
    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
        logging_config_file.write(json.dumps(valid_logging_config))
        # Flush so the content is on disk when the file is re-opened by name.
        logging_config_file.flush()
        # NOTE(review): this file does not show how vllm.logger reaches
        # dictConfig. The same mock is installed on the `logging.config`
        # attribute and on a `dictConfig` name inside vllm.logger
        # (create=True tolerates that name being absent), so the call is
        # intercepted either way. Drop whichever patch turns out to be
        # unnecessary.
        with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH",
                   logging_config_file.name), patch(
                       "logging.config.dictConfig"
                   ) as dict_config_mock, patch("vllm.logger.dictConfig",
                                                dict_config_mock,
                                                create=True):
            _configure_vllm_root_logger()
            # `called_with` is not a Mock assertion: it returns an
            # auto-created, always-truthy child Mock (and raises
            # AttributeError on Python 3.12+), so the old check could never
            # fail. Use the real assertion.
            dict_config_mock.assert_called_with(valid_logging_config)
@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0)
def test_custom_logging_config_causes_an_error_if_configure_logging_is_off():
    """Re-run the root logger setup with a config path but configuring off.

    VLLM_CONFIGURE_LOGGING is patched to 0 while VLLM_LOGGING_CONFIG_PATH
    points at a valid config file; the setup is expected to raise a
    RuntimeError. Afterwards a freshly initialised logger is compared with
    the "vllm" logger: it must differ in handlers and level, and must still
    propagate.
    """
    valid_logging_config = {
        "loggers": {
            "vllm.test_logger.logger": {
                "handlers": [],
            }
        },
        "version": 1
    }
    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
        logging_config_file.write(json.dumps(valid_logging_config))
        # Flush so the content is on disk when the file is re-opened by name.
        logging_config_file.flush()
        with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH",
                   logging_config_file.name):
            with pytest.raises(RuntimeError) as ex_info:
                _configure_vllm_root_logger()

            assert ex_info.type is RuntimeError
            expected_message_snippet = (
                "VLLM_CONFIGURE_LOGGING evaluated to false, but "
                "VLLM_LOGGING_CONFIG_PATH was given.")
            # Match against the exception message itself. str(ex_info) is
            # the ExceptionInfo repr, not the message.
            assert expected_message_snippet in str(ex_info.value)

            # Remember! The root logger is assumed to have been configured as
            # though VLLM_CONFIGURE_LOGGING=1 and VLLM_LOGGING_CONFIG_PATH=None.
            root_logger = logging.getLogger("vllm")
            other_logger_name = f"vllm.test_logger.{uuid4()}"
            other_logger = init_logger(other_logger_name)
            assert other_logger.handlers != root_logger.handlers
            assert other_logger.level != root_logger.level
            assert other_logger.propagate
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import torch import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
...@@ -69,7 +70,7 @@ def test_logits_processors(seed: int, device: str): ...@@ -69,7 +70,7 @@ def test_logits_processors(seed: int, device: str):
return logits return logits
seq_group_metadata_list = [] seq_group_metadata_list = []
prompt_lens = [] seq_lens = []
for i in range(batch_size): for i in range(batch_size):
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -80,11 +81,14 @@ def test_logits_processors(seed: int, device: str): ...@@ -80,11 +81,14 @@ def test_logits_processors(seed: int, device: str):
logits_processors=[pick_ith]), logits_processors=[pick_ith]),
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = SamplingMetadata.prepare(
prompt_lens, seq_group_metadata_list,
subquery_lens=prompt_lens) seq_lens,
query_lens=seq_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
logits_processor_output = logits_processor( logits_processor_output = logits_processor(
embedding=None, embedding=None,
hidden_states=input_tensor, hidden_states=input_tensor,
......
import pytest
from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer
# Tokenizer names that test_tokenizer_revision is parametrized over.
TOKENIZER_NAMES = [
"facebook/opt-125m",
"gpt2",
]
@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
def test_tokenizer_revision(tokenizer_name: str):
    """Check that `get_tokenizer` honours its `revision` argument.

    Loading with revision "main" must return a `PreTrainedTokenizerBase`
    instance, while loading with revision "never" must raise an `OSError`
    whose message matches 'not a valid git identifier'.
    """
    # The "main" branch is assumed to exist for every listed tokenizer.
    loaded = get_tokenizer(tokenizer_name, revision="main")
    assert isinstance(loaded, PreTrainedTokenizerBase)

    # The "never" branch is assumed to be absent, so the lookup must fail.
    expect_bad_revision = pytest.raises(OSError,
                                        match='not a valid git identifier')
    with expect_bad_revision:
        get_tokenizer(tokenizer_name, revision="never")
...@@ -2,7 +2,10 @@ import pytest ...@@ -2,7 +2,10 @@ import pytest
import torch import torch
from vllm.config import ModelConfig, SchedulerConfig from vllm.config import ModelConfig, SchedulerConfig
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import get_open_port
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
...@@ -20,14 +23,14 @@ def test_prepare_prompt(batch_size): ...@@ -20,14 +23,14 @@ def test_prepare_prompt(batch_size):
lora_config=None) lora_config=None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] seq_lens = []
seq_group_metadata_list = [] seq_group_metadata_list = []
block_tables = {0: [1]} block_tables = {0: [1]}
for i in range(batch_size): for i in range(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len) seq_lens.append(seq_len)
seq_data = SequenceData(list(range(prompt_len))) seq_data = SequenceData(list(range(seq_len)))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
...@@ -40,29 +43,29 @@ def test_prepare_prompt(batch_size): ...@@ -40,29 +43,29 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices = [] expected_selected_token_indices = []
selected_token_start_idx = 0 selected_token_start_idx = 0
for prompt_len in prompt_lens: for seq_len in seq_lens:
expected_selected_token_indices.append(selected_token_start_idx + expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1) seq_len - 1)
selected_token_start_idx += prompt_len selected_token_start_idx += seq_len
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, _, _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_seq_lens == seq_lens
assert return_prompt_lens == prompt_lens
assert len(slot_mapping) == len(input_tokens) assert len(slot_mapping) == len(input_tokens)
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
assert attn_metadata.is_prompt is True assert attn_metadata.is_prompt is True
assert torch.allclose(attn_metadata.prompt_lens_tensor, assert torch.allclose(
torch.tensor(prompt_lens, device=device)) attn_metadata.seq_lens_tensor,
assert attn_metadata.prompt_lens == prompt_lens torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.max_prompt_len == max(prompt_lens) assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_seq_len == max(seq_lens)
# Test subquery start locs. # Test subquery start locs.
start_idx = 0 start_idx = 0
start_loc = [start_idx] start_loc = [start_idx]
for prompt_len in prompt_lens: for seq_len in seq_lens:
start_idx += prompt_len start_idx += seq_len
start_loc.append(start_idx) start_loc.append(start_idx)
assert torch.allclose( assert torch.allclose(
attn_metadata.subquery_start_loc, attn_metadata.subquery_start_loc,
...@@ -72,17 +75,16 @@ def test_prepare_prompt(batch_size): ...@@ -72,17 +75,16 @@ def test_prepare_prompt(batch_size):
# equivalent to subquery_start_loc. # equivalent to subquery_start_loc.
start_idx = 0 start_idx = 0
seq_start_loc = [start_idx] seq_start_loc = [start_idx]
for prompt_len in prompt_lens: for seq_len in seq_lens:
start_idx += prompt_len start_idx += seq_len
seq_start_loc.append(start_idx) seq_start_loc.append(start_idx)
assert torch.allclose( assert torch.allclose(
attn_metadata.seq_start_loc, attn_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device)) torch.tensor(start_loc, dtype=torch.int32, device=device))
assert attn_metadata.max_context_len is None
assert torch.allclose( assert torch.allclose(
attn_metadata.context_lens, attn_metadata.context_lens_tensor,
torch.zeros(attn_metadata.context_lens.shape[0], torch.zeros(attn_metadata.context_lens_tensor.shape[0],
dtype=torch.int, dtype=torch.int,
device=device)) device=device))
...@@ -93,15 +95,18 @@ def test_prepare_prompt(batch_size): ...@@ -93,15 +95,18 @@ def test_prepare_prompt(batch_size):
# Cuda graph should not be used for prefill. # Cuda graph should not be used for prefill.
assert attn_metadata.use_cuda_graph is False assert attn_metadata.use_cuda_graph is False
assert len(input_tokens) == sum(prompt_lens) assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(prompt_lens) assert len(input_positions) == sum(seq_lens)
torch.testing.assert_close(input_tokens, input_positions) torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = SamplingMetadata.prepare(
prompt_lens, seq_group_metadata_list,
subquery_lens=prompt_lens) seq_lens,
assert len(input_tokens) == sum(prompt_lens) query_lens=seq_lens,
assert len(input_positions) == sum(prompt_lens) device=model_runner.device,
pin_memory=model_runner.pin_memory)
assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(seq_lens)
actual = sampling_metadata.selected_token_indices actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices, expected = torch.tensor(expected_selected_token_indices,
device=actual.device, device=actual.device,
...@@ -140,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -140,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size):
lora_config=None) lora_config=None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] seq_lens = []
seq_group_metadata_list = [] seq_group_metadata_list = []
for i in range(batch_size): for i in range(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len) seq_lens.append(seq_len)
seq_data = list(range(prompt_len)) seq_data = list(range(seq_len))
seq_data = SequenceData(seq_data) seq_data = SequenceData(seq_data)
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
...@@ -166,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -166,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
assert attn_metadata.is_prompt is False assert attn_metadata.is_prompt is False
assert attn_metadata.prompt_lens is None assert attn_metadata.seq_lens is None
assert attn_metadata.max_prompt_len is None
assert attn_metadata.subquery_start_loc is None assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None assert attn_metadata.seq_start_loc is None
assert attn_metadata.max_context_len == max(prompt_lens) assert attn_metadata.max_seq_len == max(seq_lens)
assert torch.allclose( assert torch.allclose(
attn_metadata.context_lens[:len(prompt_lens)], attn_metadata.seq_lens_tensor[:len(seq_lens)],
torch.tensor(prompt_lens, dtype=torch.int, device=device)) torch.tensor(seq_lens, dtype=torch.int, device=device))
# block table's first index corresponds to each batch, meaning in # block table's first index corresponds to each batch, meaning in
# decoding it is each token. # decoding it is each token.
...@@ -192,12 +196,15 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -192,12 +196,15 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify Sampling # Verify Sampling
expected_selected_token_indices = [] expected_selected_token_indices = []
selected_token_start_idx = 0 selected_token_start_idx = 0
for prompt_len in prompt_lens: for seq_len in seq_lens:
expected_selected_token_indices.append(selected_token_start_idx) expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1 selected_token_start_idx += 1
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = SamplingMetadata.prepare(
prompt_lens, seq_group_metadata_list,
subquery_lens=prompt_lens) seq_lens,
query_lens=seq_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices, expected = torch.tensor(expected_selected_token_indices,
device=actual.device, device=actual.device,
...@@ -232,29 +239,27 @@ def test_empty_seq_group(): ...@@ -232,29 +239,27 @@ def test_empty_seq_group():
assert attn_metadata is None assert attn_metadata is None
assert len(slot_mapping) == 0 assert len(slot_mapping) == 0
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, _, _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert len(input_tokens) == 0 assert len(input_tokens) == 0
assert len(input_positions) == 0 assert len(input_positions) == 0
assert attn_metadata is None assert attn_metadata is None
assert len(slot_mapping) == 0 assert len(slot_mapping) == 0
assert len(return_prompt_lens) == 0 assert len(return_seq_lens) == 0
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
def get_world_size(group=None): @pytest.fixture
return 1 def distributed_init():
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
local_rank=0)
def mock_get_process_group_ranks(group=None):
return [0]
monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size) @pytest.mark.parametrize("batch_size", list(range(2, 128)))
monkeypatch.setattr(torch.distributed, "get_process_group_ranks", @pytest.mark.parametrize("enforce_eager", [True, False])
mock_get_process_group_ranks) def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
model_config = ModelConfig( model_config = ModelConfig(
"facebook/opt-125m", "facebook/opt-125m",
...@@ -280,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): ...@@ -280,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
model_runner.set_block_size(16) model_runner.set_block_size(16)
# Add prefill requests. # Add prefill requests.
prompt_lens = [] seq_lens = []
seq_group_metadata_list = [] seq_group_metadata_list = []
prefill_metadata_list = [] prefill_metadata_list = []
decode_metadata_list = [] decode_metadata_list = []
...@@ -289,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): ...@@ -289,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
decode_batch_size = batch_size - prefill_batch_size decode_batch_size = batch_size - prefill_batch_size
for i in range(prefill_batch_size): for i in range(prefill_batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len) seq_lens.append(seq_len)
seq_data = SequenceData(list(range(prompt_len))) seq_data = SequenceData(list(range(seq_len)))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
...@@ -306,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): ...@@ -306,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
# Add decode requests # Add decode requests
for i in range(prefill_batch_size, batch_size): for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(prompt_len)) prompt_toks = list(range(seq_len))
seq_data = SequenceData(prompt_toks) seq_data = SequenceData(prompt_toks)
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
...@@ -335,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): ...@@ -335,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
else: else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size( assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size) decode_batch_size)
assert attn_metadata.num_prefill_tokens == sum(prompt_lens) assert attn_metadata.num_prefill_tokens == sum(seq_lens)
# Verify attn metadata is consistent. We don't need to test individual # Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above. # values here because they are tested above.
......
import torch import torch
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
...@@ -54,10 +55,14 @@ def test_swap() -> None: ...@@ -54,10 +55,14 @@ def test_swap() -> None:
# Test swap out. # Test swap out.
blocks_to_swap_out = {3: 72, 56: 35, 84: 34} blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
worker.execute_model(seq_group_metadata_list=[], execute_model_req = ExecuteModelRequest(
blocks_to_swap_in={}, seq_group_metadata_list=[],
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_in={},
blocks_to_copy={}) blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy={},
)
worker.execute_model(execute_model_req=execute_model_req)
for i in range(num_layers): for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i] gpu_key_cache, gpu_value_cache = gpu_cache[i]
cpu_key_cache, cpu_value_cache = cpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i]
...@@ -66,14 +71,19 @@ def test_swap() -> None: ...@@ -66,14 +71,19 @@ def test_swap() -> None:
assert allclose(gpu_value_cache[src], cpu_value_cache[dst]) assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
# Test swap in. # Test swap in.
blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71} execute_model_req.blocks_to_swap_out = {}
worker.execute_model(seq_group_metadata_list=[], execute_model_req.blocks_to_swap_in = {
blocks_to_swap_in=blocks_to_swap_in, 19: 45,
blocks_to_swap_out={}, 67: 23,
blocks_to_copy={}) 12: 78,
40: 99,
1: 71
}
worker.execute_model(execute_model_req=execute_model_req)
for i in range(num_layers): for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i] gpu_key_cache, gpu_value_cache = gpu_cache[i]
cpu_key_cache, cpu_value_cache = cpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i]
for src, dst in blocks_to_swap_in.items(): for src, dst in execute_model_req.blocks_to_swap_in.items():
assert allclose(gpu_key_cache[dst], cpu_key_cache[src]) assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
assert allclose(gpu_value_cache[dst], cpu_value_cache[src]) assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment