Commit 3209b490 (unverified), authored by Nikola Borisov, committed by GitHub
Browse files

[Bugfix] fix crash if max_tokens=None (#2570)

parent 1e4277d2
...@@ -22,6 +22,19 @@ def test_duplicated_ignored_sequence_group(): ...@@ -22,6 +22,19 @@ def test_duplicated_ignored_sequence_group():
assert len(prompts) == len(outputs) assert len(prompts) == len(outputs)
def test_max_tokens_none():
    """Run generation with ``max_tokens=None`` and check one output per prompt."""
    # Build the sampling configuration first, with no explicit token limit.
    params = SamplingParams(
        temperature=0.01,
        top_p=0.1,
        max_tokens=None,
    )
    engine = LLM(
        model="facebook/opt-125m",
        max_num_batched_tokens=4096,
        tensor_parallel_size=1,
    )
    prompts = ["Just say hello!"]
    outputs = engine.generate(prompts, sampling_params=params)
    # Every prompt must yield exactly one entry in the result list.
    assert len(prompts) == len(outputs)
if __name__ == "__main__": if __name__ == "__main__":
import pytest import pytest
pytest.main([__file__]) pytest.main([__file__])
"""Tests for the SamplingParams class.
"""
from vllm import SamplingParams
def test_max_tokens_none():
    """Constructing SamplingParams with max_tokens=None must not raise."""
    # The test passes as long as the constructor call below completes.
    options = {"temperature": 0.01, "top_p": 0.1, "max_tokens": None}
    SamplingParams(**options)
if __name__ == "__main__":
    # Allow running this test file directly as a script: hand the file
    # itself to pytest so that only the tests defined here are collected.
    import pytest

    pytest.main([__file__])
...@@ -108,7 +108,7 @@ class SamplingParams: ...@@ -108,7 +108,7 @@ class SamplingParams:
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False, include_stop_str_in_output: bool = False,
ignore_eos: bool = False, ignore_eos: bool = False,
max_tokens: int = 16, max_tokens: Optional[int] = 16,
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
...@@ -183,7 +183,7 @@ class SamplingParams: ...@@ -183,7 +183,7 @@ class SamplingParams:
if not 0.0 <= self.min_p <= 1.0: if not 0.0 <= self.min_p <= 1.0:
raise ValueError("min_p must be in [0, 1], got " raise ValueError("min_p must be in [0, 1], got "
f"{self.min_p}.") f"{self.min_p}.")
if self.max_tokens < 1: if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError( raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.") f"max_tokens must be at least 1, got {self.max_tokens}.")
if self.logprobs is not None and self.logprobs < 0: if self.logprobs is not None and self.logprobs < 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment