Unverified Commit 18b296fd authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[core] remove beam search from the core (#9105)

parent c8f26bb6
...@@ -23,7 +23,6 @@ class RequestFuncInput: ...@@ -23,7 +23,6 @@ class RequestFuncInput:
output_len: int output_len: int
model: str model: str
best_of: int = 1 best_of: int = 1
use_beam_search: bool = False
logprobs: Optional[int] = None logprobs: Optional[int] = None
multi_modal_content: Optional[dict] = None multi_modal_content: Optional[dict] = None
ignore_eos: bool = False ignore_eos: bool = False
...@@ -49,7 +48,6 @@ async def async_request_tgi( ...@@ -49,7 +48,6 @@ async def async_request_tgi(
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
params = { params = {
"best_of": request_func_input.best_of, "best_of": request_func_input.best_of,
"max_new_tokens": request_func_input.output_len, "max_new_tokens": request_func_input.output_len,
...@@ -121,7 +119,6 @@ async def async_request_trt_llm( ...@@ -121,7 +119,6 @@ async def async_request_trt_llm(
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
assert request_func_input.best_of == 1 assert request_func_input.best_of == 1
payload = { payload = {
"accumulate_tokens": True, "accumulate_tokens": True,
...@@ -187,7 +184,6 @@ async def async_request_deepspeed_mii( ...@@ -187,7 +184,6 @@ async def async_request_deepspeed_mii(
) -> RequestFuncOutput: ) -> RequestFuncOutput:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert request_func_input.best_of == 1 assert request_func_input.best_of == 1
assert not request_func_input.use_beam_search
payload = { payload = {
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
...@@ -235,7 +231,6 @@ async def async_request_openai_completions( ...@@ -235,7 +231,6 @@ async def async_request_openai_completions(
), "OpenAI Completions API URL must end with 'completions' or 'profile'." ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
payload = { payload = {
"model": request_func_input.model, "model": request_func_input.model,
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
...@@ -317,7 +312,6 @@ async def async_request_openai_chat_completions( ...@@ -317,7 +312,6 @@ async def async_request_openai_chat_completions(
), "OpenAI Chat Completions API URL must end with 'chat/completions'." ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
content = [{"type": "text", "text": request_func_input.prompt}] content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content: if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content) content.append(request_func_input.multi_modal_content)
......
...@@ -51,9 +51,8 @@ def main(args: argparse.Namespace): ...@@ -51,9 +51,8 @@ def main(args: argparse.Namespace):
sampling_params = SamplingParams( sampling_params = SamplingParams(
n=args.n, n=args.n,
temperature=0.0 if args.use_beam_search else 1.0, temperature=1.0,
top_p=1.0, top_p=1.0,
use_beam_search=args.use_beam_search,
ignore_eos=True, ignore_eos=True,
max_tokens=args.output_len, max_tokens=args.output_len,
) )
......
...@@ -68,7 +68,6 @@ def run_vllm( ...@@ -68,7 +68,6 @@ def run_vllm(
tensor_parallel_size: int, tensor_parallel_size: int,
seed: int, seed: int,
n: int, n: int,
use_beam_search: bool,
trust_remote_code: bool, trust_remote_code: bool,
dtype: str, dtype: str,
max_model_len: Optional[int], max_model_len: Optional[int],
...@@ -114,9 +113,8 @@ def run_vllm( ...@@ -114,9 +113,8 @@ def run_vllm(
sampling_params.append( sampling_params.append(
SamplingParams( SamplingParams(
n=n, n=n,
temperature=0.0 if use_beam_search else 1.0, temperature=1.0,
top_p=1.0, top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=output_len,
)) ))
...@@ -144,15 +142,16 @@ def main(args: argparse.Namespace): ...@@ -144,15 +142,16 @@ def main(args: argparse.Namespace):
args.output_len) args.output_len)
if args.backend == "vllm": if args.backend == "vllm":
elapsed_time = run_vllm( elapsed_time = run_vllm(requests, args.model, args.tokenizer,
requests, args.model, args.tokenizer, args.quantization, args.quantization, args.tensor_parallel_size,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.seed, args.n, args.trust_remote_code,
args.trust_remote_code, args.dtype, args.max_model_len, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype, args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device, args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill, args.enable_prefix_caching,
args.max_num_batched_tokens, args.gpu_memory_utilization, args.enable_chunked_prefill,
args.download_dir) args.max_num_batched_tokens,
args.gpu_memory_utilization, args.download_dir)
else: else:
raise ValueError(f"Unknown backend: {args.backend}") raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len total_num_tokens = sum(prompt_len + output_len
...@@ -203,7 +202,6 @@ if __name__ == "__main__": ...@@ -203,7 +202,6 @@ if __name__ == "__main__":
type=int, type=int,
default=1, default=1,
help="Number of generated sequences per prompt.") help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", parser.add_argument("--num-prompts",
type=int, type=int,
default=200, default=200,
......
...@@ -391,7 +391,6 @@ async def benchmark( ...@@ -391,7 +391,6 @@ async def benchmark(
input_requests: List[Tuple[str, int, int]], input_requests: List[Tuple[str, int, int]],
logprobs: Optional[int], logprobs: Optional[int],
best_of: int, best_of: int,
use_beam_search: bool,
request_rate: float, request_rate: float,
disable_tqdm: bool, disable_tqdm: bool,
profile: bool, profile: bool,
...@@ -419,7 +418,6 @@ async def benchmark( ...@@ -419,7 +418,6 @@ async def benchmark(
output_len=test_output_len, output_len=test_output_len,
logprobs=logprobs, logprobs=logprobs,
best_of=best_of, best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=test_mm_content, multi_modal_content=test_mm_content,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
) )
...@@ -441,7 +439,6 @@ async def benchmark( ...@@ -441,7 +439,6 @@ async def benchmark(
output_len=test_output_len, output_len=test_output_len,
logprobs=logprobs, logprobs=logprobs,
best_of=best_of, best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=test_mm_content, multi_modal_content=test_mm_content,
) )
profile_output = await request_func(request_func_input=profile_input) profile_output = await request_func(request_func_input=profile_input)
...@@ -464,7 +461,6 @@ async def benchmark( ...@@ -464,7 +461,6 @@ async def benchmark(
output_len=output_len, output_len=output_len,
logprobs=logprobs, logprobs=logprobs,
best_of=best_of, best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=mm_content, multi_modal_content=mm_content,
) )
tasks.append( tasks.append(
...@@ -483,7 +479,6 @@ async def benchmark( ...@@ -483,7 +479,6 @@ async def benchmark(
output_len=test_output_len, output_len=test_output_len,
logprobs=logprobs, logprobs=logprobs,
best_of=best_of, best_of=best_of,
use_beam_search=use_beam_search,
) )
profile_output = await request_func(request_func_input=profile_input) profile_output = await request_func(request_func_input=profile_input)
if profile_output.success: if profile_output.success:
...@@ -679,7 +674,6 @@ def main(args: argparse.Namespace): ...@@ -679,7 +674,6 @@ def main(args: argparse.Namespace):
input_requests=input_requests, input_requests=input_requests,
logprobs=args.logprobs, logprobs=args.logprobs,
best_of=args.best_of, best_of=args.best_of,
use_beam_search=args.use_beam_search,
request_rate=args.request_rate, request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm, disable_tqdm=args.disable_tqdm,
profile=args.profile, profile=args.profile,
...@@ -701,7 +695,6 @@ def main(args: argparse.Namespace): ...@@ -701,7 +695,6 @@ def main(args: argparse.Namespace):
result_json["model_id"] = model_id result_json["model_id"] = model_id
result_json["tokenizer_id"] = tokenizer_id result_json["tokenizer_id"] = tokenizer_id
result_json["best_of"] = args.best_of result_json["best_of"] = args.best_of
result_json["use_beam_search"] = args.use_beam_search
result_json["num_prompts"] = args.num_prompts result_json["num_prompts"] = args.num_prompts
# Metadata # Metadata
......
...@@ -73,7 +73,6 @@ def run_vllm( ...@@ -73,7 +73,6 @@ def run_vllm(
tensor_parallel_size: int, tensor_parallel_size: int,
seed: int, seed: int,
n: int, n: int,
use_beam_search: bool,
trust_remote_code: bool, trust_remote_code: bool,
dtype: str, dtype: str,
max_model_len: Optional[int], max_model_len: Optional[int],
...@@ -91,7 +90,6 @@ def run_vllm( ...@@ -91,7 +90,6 @@ def run_vllm(
download_dir: Optional[str] = None, download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format, load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False, disable_async_output_proc: bool = False,
use_new_beam_search_impl: bool = False,
) -> float: ) -> float:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM( llm = LLM(
...@@ -127,19 +125,19 @@ def run_vllm( ...@@ -127,19 +125,19 @@ def run_vllm(
sampling_params.append( sampling_params.append(
SamplingParams( SamplingParams(
n=n, n=n,
temperature=0.0 if use_beam_search else 1.0, temperature=1.0,
top_p=1.0, top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=output_len,
)) ))
if not use_new_beam_search_impl: use_beam_search = False
if not use_beam_search:
start = time.perf_counter() start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True) llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
else: else:
assert use_beam_search
prompts = [prompt for prompt, _, _ in requests] prompts = [prompt for prompt, _, _ in requests]
# output_len should be the same for all requests. # output_len should be the same for all requests.
output_len = requests[0][2] output_len = requests[0][2]
...@@ -165,7 +163,6 @@ async def run_vllm_async( ...@@ -165,7 +163,6 @@ async def run_vllm_async(
tensor_parallel_size: int, tensor_parallel_size: int,
seed: int, seed: int,
n: int, n: int,
use_beam_search: bool,
trust_remote_code: bool, trust_remote_code: bool,
dtype: str, dtype: str,
max_model_len: Optional[int], max_model_len: Optional[int],
...@@ -224,9 +221,8 @@ async def run_vllm_async( ...@@ -224,9 +221,8 @@ async def run_vllm_async(
sampling_params.append( sampling_params.append(
SamplingParams( SamplingParams(
n=n, n=n,
temperature=0.0 if use_beam_search else 1.0, temperature=1.0,
top_p=1.0, top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=output_len,
)) ))
...@@ -248,11 +244,9 @@ def run_hf( ...@@ -248,11 +244,9 @@ def run_hf(
model: str, model: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
n: int, n: int,
use_beam_search: bool,
max_batch_size: int, max_batch_size: int,
trust_remote_code: bool, trust_remote_code: bool,
) -> float: ) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained( llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama": if llm.config.model_type == "llama":
...@@ -284,7 +278,7 @@ def run_hf( ...@@ -284,7 +278,7 @@ def run_hf(
padding=True).input_ids padding=True).input_ids
llm_outputs = llm.generate( llm_outputs = llm.generate(
input_ids=input_ids.cuda(), input_ids=input_ids.cuda(),
do_sample=not use_beam_search, do_sample=True,
num_return_sequences=n, num_return_sequences=n,
temperature=1.0, temperature=1.0,
top_p=1.0, top_p=1.0,
...@@ -340,7 +334,7 @@ def main(args: argparse.Namespace): ...@@ -340,7 +334,7 @@ def main(args: argparse.Namespace):
if args.backend == "vllm": if args.backend == "vllm":
run_args = [ run_args = [
requests, args.model, args.tokenizer, args.quantization, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.tensor_parallel_size, args.seed, args.n,
args.trust_remote_code, args.dtype, args.max_model_len, args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype, args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device, args.quantization_param_path, args.device,
...@@ -355,12 +349,11 @@ def main(args: argparse.Namespace): ...@@ -355,12 +349,11 @@ def main(args: argparse.Namespace):
run_args.append(args.disable_frontend_multiprocessing) run_args.append(args.disable_frontend_multiprocessing)
elapsed_time = uvloop.run(run_vllm_async(*run_args)) elapsed_time = uvloop.run(run_vllm_async(*run_args))
else: else:
elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl) elapsed_time = run_vllm(*run_args)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size, args.hf_max_batch_size, args.trust_remote_code)
args.trust_remote_code)
elif args.backend == "mii": elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len) args.output_len)
...@@ -414,8 +407,6 @@ if __name__ == "__main__": ...@@ -414,8 +407,6 @@ if __name__ == "__main__":
type=int, type=int,
default=1, default=1,
help="Number of generated sequences per prompt.") help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--use-new-beam-search-impl", action="store_true")
parser.add_argument("--num-prompts", parser.add_argument("--num-prompts",
type=int, type=int,
default=1000, default=1000,
...@@ -570,8 +561,6 @@ if __name__ == "__main__": ...@@ -570,8 +561,6 @@ if __name__ == "__main__":
raise ValueError("dtype must be auto for MII backend.") raise ValueError("dtype must be auto for MII backend.")
if args.n != 1: if args.n != 1:
raise ValueError("n must be 1 for MII backend.") raise ValueError("n must be 1 for MII backend.")
if args.use_beam_search:
raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None: if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.") raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None: if args.hf_max_batch_size is not None:
......
...@@ -18,9 +18,6 @@ def create_test_prompts() -> List[Tuple[str, SamplingParams]]: ...@@ -18,9 +18,6 @@ def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
frequency_penalty=0.1)), frequency_penalty=0.1)),
("It is only with the heart that one can see rightly",
SamplingParams(n=3, best_of=3, use_beam_search=True,
temperature=0.0)),
] ]
......
...@@ -43,15 +43,6 @@ def create_test_prompts( ...@@ -43,15 +43,6 @@ def create_test_prompts(
max_tokens=128, max_tokens=128,
stop_token_ids=[32003]), stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)), LoRARequest("sql-lora", 1, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
SamplingParams(n=3,
best_of=3,
use_beam_search=True,
temperature=0,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
( (
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0, SamplingParams(temperature=0.0,
...@@ -60,15 +51,6 @@ def create_test_prompts( ...@@ -60,15 +51,6 @@ def create_test_prompts(
max_tokens=128, max_tokens=128,
stop_token_ids=[32003]), stop_token_ids=[32003]),
LoRARequest("sql-lora2", 2, lora_path)), LoRARequest("sql-lora2", 2, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
SamplingParams(n=3,
best_of=3,
use_beam_search=True,
temperature=0,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
] ]
......
...@@ -23,11 +23,9 @@ MODELS = [ ...@@ -23,11 +23,9 @@ MODELS = [
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
def check_settings(): def check_settings():
assert ENABLE_ARTIFICIAL_PREEMPT is True, ( assert ENABLE_ARTIFICIAL_PREEMPT is True, (
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1, " "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1."
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1. "
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 " "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 "
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1 pytest " "pytest tests/basic_correctness/test_preemption.py`")
"tests/basic_correctness/test_preemption.py`")
@pytest.fixture @pytest.fixture
...@@ -137,114 +135,6 @@ def test_preemption( ...@@ -137,114 +135,6 @@ def test_preemption(
assert total_preemption == total_recorded_preemption assert total_preemption == total_recorded_preemption
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
caplog_vllm,
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Use beam search enables swapping."""
example_prompts = example_prompts[:1]
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
with vllm_runner(
model,
dtype=dtype,
swap_space=10,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
vllm_output_ids, _ = vllm_outputs[i]
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")
assert ("is preempted by PreemptionMode.SWAP mode because there "
"is not enough KV cache space." in caplog_vllm.text)
# Ensure the count bucket of request-level histogram metrics matches
# the number of requests as a simple sanity check to ensure metrics are
# generated
preemption_metrics = None
for m in REGISTRY.collect():
if m.name == "vllm:num_preemptions":
preemption_metrics = m
assert preemption_metrics is not None
total_recorded_preemption = 0
for sample in preemption_metrics.samples:
total_recorded_preemption += sample.value
assert total_preemption == total_recorded_preemption
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
@pytest.mark.parametrize("use_v2_block_manager", [True, False])
def test_swap_infeasible(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
use_v2_block_manager: bool,
) -> None:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE = 16
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE
example_prompts = example_prompts[:1]
with vllm_runner(
model,
dtype=dtype,
swap_space=10,
block_size=BLOCK_SIZE,
# Since beam search have more than 1 sequence, prefill +
# decode blocks are not enough to finish.
num_gpu_blocks_override=prefill_blocks + decode_blocks,
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
worker_use_ray=worker_use_ray,
use_v2_block_manager=use_v2_block_manager,
) as vllm_model:
sampling_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens,
ignore_eos=True)
req_outputs = vllm_model.model.generate(
example_prompts,
sampling_params=sampling_params,
)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
# Verify the request is ignored and not hang.
assert req_outputs[0].outputs[0].finish_reason == "length"
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("max_tokens", [96])
......
...@@ -782,7 +782,6 @@ class VllmRunner: ...@@ -782,7 +782,6 @@ class VllmRunner:
List[TokensTextLogprobsPromptLogprobs]]: List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams( greedy_logprobs_params = SamplingParams(
temperature=0.0, temperature=0.0,
use_beam_search=False,
max_tokens=max_tokens, max_tokens=max_tokens,
logprobs=num_logprobs, logprobs=num_logprobs,
prompt_logprobs=(num_prompt_logprobs), prompt_logprobs=(num_prompt_logprobs),
...@@ -795,19 +794,6 @@ class VllmRunner: ...@@ -795,19 +794,6 @@ class VllmRunner:
encoder_decoder_prompts, greedy_logprobs_params) encoder_decoder_prompts, greedy_logprobs_params)
def generate_beam_search( def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]:
beam_search_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens)
outputs = self.generate(prompts, beam_search_params)
return outputs
def generate_beam_search_new(
self, self,
prompts: Union[List[str], List[List[int]]], prompts: Union[List[str], List[List[int]]],
beam_width: int, beam_width: int,
......
...@@ -85,73 +85,6 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, ...@@ -85,73 +85,6 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
assert baseline_token_ids == test_token_ids assert baseline_token_ids == test_token_ids
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",
# skip cuda graph creation for fast test.
"enforce_eager": True,
# Use a large block size to trigger more copy-on-writes.
"block_size": 32,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"preemption_mode": "swap"
}, {
"use_v2_block_manager": True,
"preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify beam search equality with block manager v1 and v2.
This requires copy-on-writes; if the v1 and v2 output is the same, then
we have some confidence cow is working.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
use_beam_search=True,
best_of=2,
)
print('Getting token ids from block manager v1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids from block manager v2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids
assert baseline_token_ids == test_token_ids
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
......
...@@ -13,7 +13,6 @@ def create_dummy_prompt( ...@@ -13,7 +13,6 @@ def create_dummy_prompt(
prompt_length: int, prompt_length: int,
block_size: Optional[int] = None, block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1, best_of: int = 1,
prompt_tokens: Optional[List[int]] = None, prompt_tokens: Optional[List[int]] = None,
min_tokens: int = 0, min_tokens: int = 0,
...@@ -37,7 +36,6 @@ def create_dummy_prompt( ...@@ -37,7 +36,6 @@ def create_dummy_prompt(
seqs=[prompt], seqs=[prompt],
arrival_time=time.time(), arrival_time=time.time(),
sampling_params=SamplingParams( sampling_params=SamplingParams(
use_beam_search=use_beam_search,
best_of=best_of, best_of=best_of,
max_tokens=max_tokens, max_tokens=max_tokens,
min_tokens=min_tokens), min_tokens=min_tokens),
...@@ -52,7 +50,6 @@ def create_dummy_prompt_encoder_decoder( ...@@ -52,7 +50,6 @@ def create_dummy_prompt_encoder_decoder(
encoder_prompt_length: int, encoder_prompt_length: int,
block_size: Optional[int] = None, block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1, best_of: int = 1,
) -> Tuple[Sequence, Sequence, SequenceGroup]: ) -> Tuple[Sequence, Sequence, SequenceGroup]:
if not block_size: if not block_size:
...@@ -85,9 +82,7 @@ def create_dummy_prompt_encoder_decoder( ...@@ -85,9 +82,7 @@ def create_dummy_prompt_encoder_decoder(
from_decoder_prompt=False) from_decoder_prompt=False)
seq_group = SequenceGroup(request_id=request_id, seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt], seqs=[decoder_prompt],
sampling_params=SamplingParams( sampling_params=SamplingParams(best_of=best_of),
use_beam_search=use_beam_search,
best_of=best_of),
arrival_time=time.time(), arrival_time=time.time(),
lora_request=lora_request, lora_request=lora_request,
encoder_seq=encoder_prompt) encoder_seq=encoder_prompt)
......
...@@ -33,8 +33,8 @@ def test_beam_search_single_input( ...@@ -33,8 +33,8 @@ def test_beam_search_single_input(
max_tokens) max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model: with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search_new( vllm_outputs = vllm_model.generate_beam_search(example_prompts,
example_prompts, beam_width, max_tokens) beam_width, max_tokens)
for i in range(len(example_prompts)): for i in range(len(example_prompts)):
hf_output_ids, hf_output_texts = hf_outputs[i] hf_output_ids, hf_output_texts = hf_outputs[i]
......
...@@ -159,26 +159,6 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): ...@@ -159,26 +159,6 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
assert first_sampler_output == second_sampler_output assert first_sampler_output == second_sampler_output
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
sampling_params = SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, fake_logits, sampler, sampling_params, device)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
# when handling an all-beam search case.
@pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str): def test_sampler_min_tokens_penalty(seed: int, device: str):
...@@ -479,7 +459,7 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -479,7 +459,7 @@ def test_sampler_mixed(seed: int, device: str):
seq_lens: List[int] = [] seq_lens: List[int] = []
for i in range(batch_size): for i in range(batch_size):
expected: Optional[List[int]] = None expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3) sampling_type = random.randint(0, 2)
if sampling_type == 0: if sampling_type == 0:
sampling_params = SamplingParams(temperature=0) sampling_params = SamplingParams(temperature=0)
expected = [int(torch.argmax(fake_logits[i], dim=-1).item())] expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
...@@ -498,10 +478,7 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -498,10 +478,7 @@ def test_sampler_mixed(seed: int, device: str):
for idx in range(n): for idx in range(n):
fake_logits[i, i + idx] = 1e2 fake_logits[i, i + idx] = 1e2
expected = list(range(i, i + n)) expected = list(range(i, i + n))
else:
sampling_params = SamplingParams(temperature=0,
use_beam_search=True,
best_of=2)
expected_tokens.append(expected) expected_tokens.append(expected)
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
...@@ -530,9 +507,6 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -530,9 +507,6 @@ def test_sampler_mixed(seed: int, device: str):
zip(sampler_output, seq_group_metadata_list)): zip(sampler_output, seq_group_metadata_list)):
assert metadata.sampling_params is not None assert metadata.sampling_params is not None
if metadata.sampling_params.use_beam_search:
continue
if (metadata.sampling_params.seed is not None if (metadata.sampling_params.seed is not None
and expected_tokens[i] is None): and expected_tokens[i] is None):
# Record seeded random result to compare with results of # Record seeded random result to compare with results of
......
...@@ -1202,9 +1202,9 @@ class Scheduler: ...@@ -1202,9 +1202,9 @@ class Scheduler:
seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
# TODO: does it work with parallel sampling?
no_beam_search = seq_group.sampling_params is None or ( no_beam_search = seq_group.sampling_params is None or (
seq_group.sampling_params.best_of == 1 seq_group.sampling_params.best_of == 1)
and not seq_group.sampling_params.use_beam_search)
return no_beam_search return no_beam_search
def schedule( def schedule(
......
...@@ -33,7 +33,7 @@ from vllm.sequence import ExecuteModelRequest ...@@ -33,7 +33,7 @@ from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (collect_from_async_generator, deprecate_kwargs, from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
random_uuid, weak_bind) get_beam_search_score, random_uuid, weak_bind)
logger = init_logger(__name__) logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
...@@ -1050,6 +1050,12 @@ class AsyncLLMEngine: ...@@ -1050,6 +1050,12 @@ class AsyncLLMEngine:
max_tokens = params.max_tokens max_tokens = params.max_tokens
ignore_eos = params.ignore_eos ignore_eos = params.ignore_eos
temperature = params.temperature temperature = params.temperature
length_penalty = params.length_penalty
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)
tokenizer = await self.get_tokenizer() tokenizer = await self.get_tokenizer()
tokenizedPrompt = prompt if isinstance( tokenizedPrompt = prompt if isinstance(
...@@ -1103,15 +1109,11 @@ class AsyncLLMEngine: ...@@ -1103,15 +1109,11 @@ class AsyncLLMEngine:
else: else:
new_beams.append(new_beam) new_beams.append(new_beam)
sorted_beams = sorted(new_beams, sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
key=lambda x: x.cum_logprob,
reverse=True)
all_beams = sorted_beams[:beam_width] all_beams = sorted_beams[:beam_width]
completed.extend(all_beams) completed.extend(all_beams)
sorted_completed = sorted(completed, sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
key=lambda x: x.cum_logprob,
reverse=True)
best_beams = sorted_completed[:beam_width] best_beams = sorted_completed[:beam_width]
for beam in best_beams: for beam in best_beams:
......
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Tuple
from vllm.config import SchedulerConfig from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
...@@ -6,7 +6,6 @@ from vllm.engine.output_processor.interfaces import ( ...@@ -6,7 +6,6 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor) SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus) SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
...@@ -113,7 +112,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -113,7 +112,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
outputs: SequenceGroupOutput, outputs: SequenceGroupOutput,
is_async: bool) -> None: is_async: bool) -> None:
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
if sampling_params.best_of == 1 and not sampling_params.use_beam_search: if sampling_params.best_of == 1:
# only have one output sample # only have one output sample
sample = outputs.samples[0] sample = outputs.samples[0]
# only have one sequence # only have one sequence
...@@ -142,7 +141,6 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -142,7 +141,6 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Process samples # Process samples
samples = outputs.samples samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict: Dict[int, List[SequenceOutput]] = { parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: [] parent_seq.seq_id: []
for parent_seq in parent_seqs for parent_seq in parent_seqs
...@@ -197,8 +195,6 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -197,8 +195,6 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
lora_req=seq_group.lora_request, lora_req=seq_group.lora_request,
) )
# Non-beam search case
if not sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group # For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished. # and fork them in block manager if they are not finished.
for seq, parent in child_seqs: for seq, parent in child_seqs:
...@@ -217,149 +213,3 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -217,149 +213,3 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
for scheduler in self.scheduler: for scheduler in self.scheduler:
scheduler.free_seq(seq) scheduler.free_seq(seq)
return return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
beam_width = sampling_params.best_of
length_penalty = sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False)
for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
sampling_params.early_stopping, sampling_params,
best_running_seq, current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
for scheduler in self.scheduler:
scheduler.free_seq(seq)
def _check_beam_search_early_stopping(
self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=current_worst_seq.eos_token_id)
if early_stopping is False:
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id)
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() +
sampling_params.max_tokens,
self.scheduler_config.max_model_len)
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id))
return current_worst_score >= highest_attainable_score
...@@ -28,7 +28,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, ...@@ -28,7 +28,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer) get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs, is_list_of from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
is_list_of)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -404,6 +405,12 @@ class LLM: ...@@ -404,6 +405,12 @@ class LLM:
max_tokens = params.max_tokens max_tokens = params.max_tokens
temperature = params.temperature temperature = params.temperature
ignore_eos = params.ignore_eos ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step # generate 2 * beam_width candidates at each step
...@@ -466,7 +473,7 @@ class LLM: ...@@ -466,7 +473,7 @@ class LLM:
else: else:
instance_new_beams.append(new_beam) instance_new_beams.append(new_beam)
sorted_beams = sorted(instance_new_beams, sorted_beams = sorted(instance_new_beams,
key=lambda x: x.cum_logprob, key=sort_beams_key,
reverse=True) reverse=True)
instance.beams = sorted_beams[:beam_width] instance.beams = sorted_beams[:beam_width]
...@@ -474,7 +481,7 @@ class LLM: ...@@ -474,7 +481,7 @@ class LLM:
for instance in instances: for instance in instances:
instance.completed.extend(instance.beams) instance.completed.extend(instance.beams)
sorted_completed = sorted(instance.completed, sorted_completed = sorted(instance.completed,
key=lambda x: x.cum_logprob, key=sort_beams_key,
reverse=True) reverse=True)
best_beams = sorted_completed[:beam_width] best_beams = sorted_completed[:beam_width]
......
...@@ -184,7 +184,6 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -184,7 +184,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
min_p: float = 0.0 min_p: float = 0.0
repetition_penalty: float = 1.0 repetition_penalty: float = 1.0
length_penalty: float = 1.0 length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False include_stop_str_in_output: bool = False
ignore_eos: bool = False ignore_eos: bool = False
...@@ -302,6 +301,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -302,6 +301,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
max_tokens=max_tokens, max_tokens=max_tokens,
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
temperature=temperature, temperature=temperature,
length_penalty=self.length_penalty,
) )
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
...@@ -345,12 +345,9 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -345,12 +345,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
max_tokens=max_tokens, max_tokens=max_tokens,
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \ output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
...@@ -518,7 +515,6 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -518,7 +515,6 @@ class CompletionRequest(OpenAIBaseModel):
min_p: float = 0.0 min_p: float = 0.0
repetition_penalty: float = 1.0 repetition_penalty: float = 1.0
length_penalty: float = 1.0 length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False include_stop_str_in_output: bool = False
ignore_eos: bool = False ignore_eos: bool = False
...@@ -597,6 +593,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -597,6 +593,7 @@ class CompletionRequest(OpenAIBaseModel):
max_tokens=max_tokens, max_tokens=max_tokens,
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
temperature=temperature, temperature=temperature,
length_penalty=self.length_penalty,
) )
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
...@@ -641,13 +638,10 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -641,13 +638,10 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
max_tokens=max_tokens if not echo_without_generation else 1, max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=prompt_logprobs, prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \ output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
......
...@@ -63,7 +63,6 @@ if TYPE_CHECKING: ...@@ -63,7 +63,6 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False
VLLM_SKIP_P2P_CHECK: bool = False VLLM_SKIP_P2P_CHECK: bool = False
...@@ -198,10 +197,6 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -198,10 +197,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")), ("true", "1")),
# If set, allowing the use of deprecated beam search implementation
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH":
lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1",
# Internal flag to enable Dynamo graph capture # Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE": "VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
......
...@@ -947,8 +947,6 @@ def get_logprobs( ...@@ -947,8 +947,6 @@ def get_logprobs(
# largest num logprobs in this API. If every logprobs is None, it will be # largest num logprobs in this API. If every logprobs is None, it will be
# set to -1. # set to -1.
largest_num_logprobs = -1 largest_num_logprobs = -1
# If beam search is enabled.
use_beam_search = False
# Select indices to compute logprob from, ranks of token ids, and the top # Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs. # k token ids from logprobs.
...@@ -981,8 +979,6 @@ def get_logprobs( ...@@ -981,8 +979,6 @@ def get_logprobs(
largest_num_logprobs = max(largest_num_logprobs, largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs) sampling_params.logprobs)
use_beam_search = use_beam_search or sampling_params.use_beam_search
assert len(next_token_ids) == len(query_indices) assert len(next_token_ids) == len(query_indices)
if len(query_indices) == 0: if len(query_indices) == 0:
...@@ -995,7 +991,7 @@ def get_logprobs( ...@@ -995,7 +991,7 @@ def get_logprobs(
# If largest_num_logprobs == -1, i.e. no logprobs are requested, we can # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
# skip the whole logprob calculation. # skip the whole logprob calculation.
if largest_num_logprobs >= 0 or use_beam_search: if largest_num_logprobs >= 0:
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids, next_token_ids_gpu = torch.tensor(next_token_ids,
device=logprobs.device) device=logprobs.device)
...@@ -1121,13 +1117,12 @@ def _get_sampled_logprob_if_needed( ...@@ -1121,13 +1117,12 @@ def _get_sampled_logprob_if_needed(
"""Compute the sample logprob if needed.""" """Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs num_logprobs = seq_group.sampling_params.logprobs
use_beam_search = seq_group.sampling_params.use_beam_search
sampled_logprobs: SampleLogprobs = [] sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result next_token_ids, parent_seq_ids = sample_result
if seq_group.do_sample: if seq_group.do_sample:
assert len(next_token_ids) > 0 assert len(next_token_ids) > 0
if num_logprobs is None and not use_beam_search: if num_logprobs is None:
for next_token_id in next_token_ids: for next_token_id in next_token_ids:
# Use a dummy logprob # Use a dummy logprob
sampled_logprobs.append({next_token_id: Logprob(inf)}) sampled_logprobs.append({next_token_id: Logprob(inf)})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment