Unverified commit d9784107, authored by Ning Xie and committed by GitHub
Browse files

[Misc] unify variable for LLM instance (#20996)


Signed-off-by: Andy Xie <andy.xning@gmail.com>
parent e6b90a28
...@@ -107,11 +107,11 @@ def test_quark_fp8_parity(vllm_runner): ...@@ -107,11 +107,11 @@ def test_quark_fp8_parity(vllm_runner):
} }
with (vllm_runner(quark_model_id, **llm_kwargs) as with (vllm_runner(quark_model_id, **llm_kwargs) as
quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle): quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
quark_model = (quark_handle.model.llm_engine.model_executor. quark_model = (quark_handle.llm.llm_engine.model_executor.
driver_worker.model_runner.model) driver_worker.model_runner.model)
quark_state_dict = quark_model.state_dict() quark_state_dict = quark_model.state_dict()
fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker. fp8_model = (fp8_handle.llm.llm_engine.model_executor.driver_worker.
model_runner.model) model_runner.model)
fp8_state_dict = fp8_model.state_dict() fp8_state_dict = fp8_model.state_dict()
......
...@@ -111,7 +111,7 @@ def test_custom_quant(vllm_runner, model, monkeypatch): ...@@ -111,7 +111,7 @@ def test_custom_quant(vllm_runner, model, monkeypatch):
quantization="custom_quant", quantization="custom_quant",
enforce_eager=True) as llm: enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 model = llm.llm.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0] layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj qkv_proj = layer.self_attn.qkv_proj
......
...@@ -36,7 +36,7 @@ def test_ignore_eos( ...@@ -36,7 +36,7 @@ def test_ignore_eos(
ignore_eos=True) ignore_eos=True)
for prompt in example_prompts: for prompt in example_prompts:
ignore_eos_output = vllm_model.model.generate( ignore_eos_output = vllm_model.llm.generate(
prompt, sampling_params=sampling_params) prompt, sampling_params=sampling_params)
output_length = len(ignore_eos_output[0].outputs[0].token_ids) output_length = len(ignore_eos_output[0].outputs[0].token_ids)
assert output_length == max_tokens assert output_length == max_tokens
...@@ -26,7 +26,7 @@ def test_logits_processor_force_generate( ...@@ -26,7 +26,7 @@ def test_logits_processor_force_generate(
dtype: str, dtype: str,
) -> None: ) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model: with vllm_runner(model, dtype=dtype) as vllm_model:
tokenizer = vllm_model.model.get_tokenizer() tokenizer = vllm_model.llm.get_tokenizer()
repeat_times = 2 repeat_times = 2
enforced_answers = " vLLM" enforced_answers = " vLLM"
vllm_token_ids = tokenizer.encode(enforced_answers, vllm_token_ids = tokenizer.encode(enforced_answers,
...@@ -45,13 +45,13 @@ def test_logits_processor_force_generate( ...@@ -45,13 +45,13 @@ def test_logits_processor_force_generate(
) )
# test logits_processors when prompt_logprobs is not None # test logits_processors when prompt_logprobs is not None
vllm_model.model._add_request( vllm_model.llm._add_request(
example_prompts[0], example_prompts[0],
params=params_with_logprobs, params=params_with_logprobs,
) )
# test prompt_logprobs is not None # test prompt_logprobs is not None
vllm_model.model._add_request( vllm_model.llm._add_request(
example_prompts[1], example_prompts[1],
params=SamplingParams( params=SamplingParams(
prompt_logprobs=3, prompt_logprobs=3,
...@@ -60,11 +60,11 @@ def test_logits_processor_force_generate( ...@@ -60,11 +60,11 @@ def test_logits_processor_force_generate(
) )
# test grouped requests # test grouped requests
vllm_model.model._add_request( vllm_model.llm._add_request(
example_prompts[2], example_prompts[2],
params=SamplingParams(max_tokens=max_tokens), params=SamplingParams(max_tokens=max_tokens),
) )
outputs = vllm_model.model._run_engine(use_tqdm=False) outputs = vllm_model.llm._run_engine(use_tqdm=False)
assert outputs[0].outputs[0].text == enforced_answers * repeat_times assert outputs[0].outputs[0].text == enforced_answers * repeat_times
...@@ -64,7 +64,7 @@ def test_get_prompt_logprobs( ...@@ -64,7 +64,7 @@ def test_get_prompt_logprobs(
prompt_logprobs=num_top_logprobs, prompt_logprobs=num_top_logprobs,
temperature=0.0, temperature=0.0,
detokenize=detokenize) detokenize=detokenize)
vllm_results = vllm_model.model.generate( vllm_results = vllm_model.llm.generate(
example_prompts, sampling_params=vllm_sampling_params) example_prompts, sampling_params=vllm_sampling_params)
# Test whether logprobs are included in the results. # Test whether logprobs are included in the results.
...@@ -174,7 +174,7 @@ def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int, ...@@ -174,7 +174,7 @@ def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
logprobs=None, logprobs=None,
temperature=0.0, temperature=0.0,
detokenize=detokenize) detokenize=detokenize)
results_logprobs_none = vllm_model.model.generate( results_logprobs_none = vllm_model.llm.generate(
example_prompts, sampling_params=sampling_params_logprobs_none) example_prompts, sampling_params=sampling_params_logprobs_none)
for i in range(len(results_logprobs_none)): for i in range(len(results_logprobs_none)):
......
...@@ -20,7 +20,7 @@ def v1(run_with_both_engines): ...@@ -20,7 +20,7 @@ def v1(run_with_both_engines):
def _generate( def _generate(
model: LLM, llm: LLM,
prompt: str, prompt: str,
num_prompt_tokens: int, num_prompt_tokens: int,
temperature: float = 0, temperature: float = 0,
...@@ -32,7 +32,7 @@ def _generate( ...@@ -32,7 +32,7 @@ def _generate(
) )
# [([output_token_ids, ], [output_text, ]), ] # [([output_token_ids, ], [output_text, ]), ]
output = model.generate([prompt], sampling_params=sampling_params) output = llm.generate([prompt], sampling_params=sampling_params)
output_token_ids = output[0][0][0][num_prompt_tokens:] output_token_ids = output[0][0][0][num_prompt_tokens:]
# [0] first (and only) request output # [0] first (and only) request output
...@@ -66,10 +66,10 @@ class TestOneTokenBadWord: ...@@ -66,10 +66,10 @@ class TestOneTokenBadWord:
assert self.target_token_id not in output_token_ids assert self.target_token_id not in output_token_ids
def _generate(self, def _generate(self,
model: LLM, llm: LLM,
bad_words: Optional[list[str]] = None) -> list[int]: bad_words: Optional[list[str]] = None) -> list[int]:
return _generate( return _generate(
model=model, llm=llm,
prompt=self.PROMPT, prompt=self.PROMPT,
num_prompt_tokens=self.num_prompt_tokens, num_prompt_tokens=self.num_prompt_tokens,
bad_words=bad_words, bad_words=bad_words,
...@@ -156,10 +156,10 @@ class TestTwoTokenBadWord: ...@@ -156,10 +156,10 @@ class TestTwoTokenBadWord:
or (self.neighbour_token_id2 in output_token_ids)) or (self.neighbour_token_id2 in output_token_ids))
def _generate(self, def _generate(self,
model: LLM, llm: LLM,
bad_words: Optional[list[str]] = None) -> list[int]: bad_words: Optional[list[str]] = None) -> list[int]:
return _generate( return _generate(
model=model, llm=llm,
prompt=self.PROMPT, prompt=self.PROMPT,
num_prompt_tokens=self.num_prompt_tokens, num_prompt_tokens=self.num_prompt_tokens,
bad_words=bad_words, bad_words=bad_words,
......
...@@ -49,7 +49,7 @@ def test_random_sample_with_seed( ...@@ -49,7 +49,7 @@ def test_random_sample_with_seed(
sampling_params_seed_2 = copy.deepcopy(sampling_params) sampling_params_seed_2 = copy.deepcopy(sampling_params)
sampling_params_seed_2.seed = 200 sampling_params_seed_2.seed = 200
llm = vllm_model.model llm = vllm_model.llm
for prompt in example_prompts: for prompt in example_prompts:
for params in ( for params in (
......
...@@ -393,7 +393,7 @@ def test_decode_prompt_logprobs_chunked_prefill( ...@@ -393,7 +393,7 @@ def test_decode_prompt_logprobs_chunked_prefill(
logprobs=5, logprobs=5,
prompt_logprobs=5, prompt_logprobs=5,
temperature=0.0) temperature=0.0)
vllm_results = vllm_model.model.generate( vllm_results = vllm_model.llm.generate(
example_prompts, sampling_params=vllm_sampling_params) example_prompts, sampling_params=vllm_sampling_params)
for idx, result in enumerate(vllm_results): for idx, result in enumerate(vllm_results):
......
...@@ -14,7 +14,7 @@ PROMPT = "Hello my name is Robert and I" ...@@ -14,7 +14,7 @@ PROMPT = "Hello my name is Robert and I"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def model() -> LLM: def llm() -> LLM:
return LLM(MODEL, return LLM(MODEL,
enforce_eager=True, enforce_eager=True,
enable_prefix_caching=True, enable_prefix_caching=True,
...@@ -24,16 +24,16 @@ def model() -> LLM: ...@@ -24,16 +24,16 @@ def model() -> LLM:
block_size=16) block_size=16)
def test_concurrent_partial_prefill(model): def test_concurrent_partial_prefill(llm):
outputs = model.generate([PROMPT] * 3) outputs = llm.generate([PROMPT] * 3)
assert len(outputs) == 3 assert len(outputs) == 3
for output in outputs: for output in outputs:
assert len(output.outputs) == 1 assert len(output.outputs) == 1
def test_prefix_cache_stats_is_recorded(model): def test_prefix_cache_stats_is_recorded(llm):
# 17 tokens will make sure first 16 tokens are cached in a block # 17 tokens will make sure first 16 tokens are cached in a block
input_tokens = {"prompt_token_ids": [101] * 17} input_tokens = {"prompt_token_ids": [101] * 17}
_ = model.generate([input_tokens]) _ = llm.generate([input_tokens])
outputs = model.generate([input_tokens]) outputs = llm.generate([input_tokens])
assert outputs[0].num_cached_tokens == 16 assert outputs[0].num_cached_tokens == 16
...@@ -112,9 +112,9 @@ def test_compatibility_with_skip_tokenizer_init( ...@@ -112,9 +112,9 @@ def test_compatibility_with_skip_tokenizer_init(
example_prompts, example_prompts,
structured_outputs=True, structured_outputs=True,
) )
model: LLM = vllm_model_skip_tokenizer_init.model llm: LLM = vllm_model_skip_tokenizer_init.llm
with pytest.raises(ValueError): with pytest.raises(ValueError):
_ = model.generate(example_prompts, sampling_params_list) _ = llm.generate(example_prompts, sampling_params_list)
def test_parallel_sampling(vllm_model, example_prompts) -> None: def test_parallel_sampling(vllm_model, example_prompts) -> None:
...@@ -125,8 +125,8 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None: ...@@ -125,8 +125,8 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
example_prompt: test fixture providing prompts for testing. example_prompt: test fixture providing prompts for testing.
""" """
sampling_params_list, n_list = _get_test_sampling_params(example_prompts) sampling_params_list, n_list = _get_test_sampling_params(example_prompts)
model: LLM = vllm_model.model llm: LLM = vllm_model.llm
outputs = model.generate(example_prompts, sampling_params_list) outputs = llm.generate(example_prompts, sampling_params_list)
# Validate each request response # Validate each request response
for out, n in zip(outputs, n_list): for out, n in zip(outputs, n_list):
...@@ -166,10 +166,10 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts): ...@@ -166,10 +166,10 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
speculative_config=speculative_config, speculative_config=speculative_config,
disable_log_stats=False, disable_log_stats=False,
) as vllm_model: ) as vllm_model:
model: LLM = vllm_model.model llm: LLM = vllm_model.llm
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens) max_tokens=max_tokens)
outputs = model.generate(example_prompts, sampling_params) outputs = llm.generate(example_prompts, sampling_params)
n_prompts = len(example_prompts) n_prompts = len(example_prompts)
assert len(outputs) == n_prompts assert len(outputs) == n_prompts
...@@ -180,7 +180,7 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts): ...@@ -180,7 +180,7 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
total_tokens += len(out.outputs[0].token_ids) total_tokens += len(out.outputs[0].token_ids)
assert total_tokens == max_tokens * n_prompts assert total_tokens == max_tokens * n_prompts
metrics = model.get_metrics() metrics = llm.get_metrics()
def find_metric(name) -> list[Metric]: def find_metric(name) -> list[Metric]:
found = [] found = []
......
...@@ -112,7 +112,7 @@ def _run_and_validate( ...@@ -112,7 +112,7 @@ def _run_and_validate(
max_tokens: int, max_tokens: int,
do_apc: bool, do_apc: bool,
) -> None: ) -> None:
vllm_results = vllm_model.model.generate( vllm_results = vllm_model.llm.generate(
test_prompts, sampling_params=vllm_sampling_params) test_prompts, sampling_params=vllm_sampling_params)
for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip( for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
...@@ -288,7 +288,7 @@ def test_get_logprobs_and_prompt_logprobs( ...@@ -288,7 +288,7 @@ def test_get_logprobs_and_prompt_logprobs(
""" """
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
if do_apc and (temperature < 2.0 if do_apc and (temperature < 2.0
or batch_logprobs_composition != SAMPLE_PROMPT): or batch_logprobs_composition != SAMPLE_PROMPT):
# Skip some test-cases to save time. # Skip some test-cases to save time.
...@@ -378,7 +378,7 @@ def test_none_logprobs(vllm_model, example_prompts, ...@@ -378,7 +378,7 @@ def test_none_logprobs(vllm_model, example_prompts,
prompt_logprobs=None, prompt_logprobs=None,
temperature=0.0, temperature=0.0,
) )
results_logprobs_none = vllm_model.model.generate( results_logprobs_none = vllm_model.llm.generate(
example_prompts, example_prompts,
sampling_params=sampling_params_logprobs_none, sampling_params=sampling_params_logprobs_none,
) )
...@@ -408,7 +408,7 @@ def test_zero_logprobs(vllm_model, example_prompts, ...@@ -408,7 +408,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
logprobs=0, logprobs=0,
prompt_logprobs=0, prompt_logprobs=0,
temperature=0.0) temperature=0.0)
results_logprobs_zero = vllm_model.model.generate( results_logprobs_zero = vllm_model.llm.generate(
example_prompts, sampling_params=sampling_params_logprobs_zero) example_prompts, sampling_params=sampling_params_logprobs_zero)
for i in range(len(results_logprobs_zero)): for i in range(len(results_logprobs_zero)):
......
...@@ -14,30 +14,30 @@ PROMPT = "Hello my name is Robert and I" ...@@ -14,30 +14,30 @@ PROMPT = "Hello my name is Robert and I"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def model() -> LLM: def llm() -> LLM:
# Disable prefix caching so that we can test prompt logprobs. # Disable prefix caching so that we can test prompt logprobs.
# TODO remove this after https://github.com/vllm-project/vllm/pull/13949 # TODO remove this after https://github.com/vllm-project/vllm/pull/13949
# is merged # is merged
return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False) return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False)
def test_n_gt_1(model): def test_n_gt_1(llm):
"""ParallelSampling is supported.""" """ParallelSampling is supported."""
params = SamplingParams(n=3) params = SamplingParams(n=3)
outputs = model.generate(PROMPT, params) outputs = llm.generate(PROMPT, params)
assert len(outputs[0].outputs) == 3 assert len(outputs[0].outputs) == 3
def test_best_of(model): def test_best_of(llm):
"""Raise a ValueError since best_of is deprecated.""" """Raise a ValueError since best_of is deprecated."""
params = SamplingParams(n=2, best_of=3) params = SamplingParams(n=2, best_of=3)
with pytest.raises(ValueError): with pytest.raises(ValueError):
_ = model.generate(PROMPT, params) _ = llm.generate(PROMPT, params)
def test_penalties(model): def test_penalties(llm):
"""Check that we do not get errors if applied.""" """Check that we do not get errors if applied."""
params = SamplingParams( params = SamplingParams(
...@@ -49,18 +49,18 @@ def test_penalties(model): ...@@ -49,18 +49,18 @@ def test_penalties(model):
top_p=0.5, top_p=0.5,
top_k=3, top_k=3,
) )
_ = model.generate(PROMPT, params) _ = llm.generate(PROMPT, params)
def test_stop(model): def test_stop(llm):
"""Check that we respect the stop words.""" """Check that we respect the stop words."""
output = model.generate(PROMPT, SamplingParams(temperature=0)) output = llm.generate(PROMPT, SamplingParams(temperature=0))
split_text = output[0].outputs[0].text.split() split_text = output[0].outputs[0].text.split()
STOP_IDX = 5 STOP_IDX = 5
params = SamplingParams(temperature=0, stop=split_text[STOP_IDX]) params = SamplingParams(temperature=0, stop=split_text[STOP_IDX])
output = model.generate(PROMPT, params) output = llm.generate(PROMPT, params)
new_split_text = output[0].outputs[0].text.split() new_split_text = output[0].outputs[0].text.split()
# Output should not contain the stop word. # Output should not contain the stop word.
...@@ -69,40 +69,40 @@ def test_stop(model): ...@@ -69,40 +69,40 @@ def test_stop(model):
params = SamplingParams(temperature=0, params = SamplingParams(temperature=0,
stop=split_text[STOP_IDX], stop=split_text[STOP_IDX],
include_stop_str_in_output=True) include_stop_str_in_output=True)
output = model.generate(PROMPT, params) output = llm.generate(PROMPT, params)
new_split_text = output[0].outputs[0].text.split() new_split_text = output[0].outputs[0].text.split()
# Output should contain the stop word. # Output should contain the stop word.
assert len(new_split_text) == STOP_IDX + 1 assert len(new_split_text) == STOP_IDX + 1
def test_stop_token_ids(model): def test_stop_token_ids(llm):
"""Check that we respect the stop token ids.""" """Check that we respect the stop token ids."""
output = model.generate(PROMPT, SamplingParams(temperature=0)) output = llm.generate(PROMPT, SamplingParams(temperature=0))
stop_token_id_0 = output[0].outputs[0].token_ids[5] stop_token_id_0 = output[0].outputs[0].token_ids[5]
stop_token_id_1 = output[0].outputs[0].token_ids[6] stop_token_id_1 = output[0].outputs[0].token_ids[6]
stop_token_ids = [stop_token_id_1, stop_token_id_0] stop_token_ids = [stop_token_id_1, stop_token_id_0]
params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids) params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
output = model.generate(PROMPT, params) output = llm.generate(PROMPT, params)
assert output[0].outputs[0].token_ids[-1] == stop_token_id_0 assert output[0].outputs[0].token_ids[-1] == stop_token_id_0
stop_token_ids = [stop_token_id_0, stop_token_id_1] stop_token_ids = [stop_token_id_0, stop_token_id_1]
params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids) params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
output = model.generate(PROMPT, params) output = llm.generate(PROMPT, params)
assert output[0].outputs[0].token_ids[-1] == stop_token_id_0 assert output[0].outputs[0].token_ids[-1] == stop_token_id_0
def test_detokenize_false(model): def test_detokenize_false(llm):
"""Check that detokenize=False option works.""" """Check that detokenize=False option works."""
output = model.generate(PROMPT, SamplingParams(detokenize=False)) output = llm.generate(PROMPT, SamplingParams(detokenize=False))
assert len(output[0].outputs[0].token_ids) > 0 assert len(output[0].outputs[0].token_ids) > 0
assert len(output[0].outputs[0].text) == 0 assert len(output[0].outputs[0].text) == 0
output = model.generate( output = llm.generate(
PROMPT, SamplingParams(detokenize=False, logprobs=3, PROMPT, SamplingParams(detokenize=False, logprobs=3,
prompt_logprobs=3)) prompt_logprobs=3))
assert len(output[0].outputs[0].token_ids) > 0 assert len(output[0].outputs[0].token_ids) > 0
...@@ -118,28 +118,28 @@ def test_detokenize_false(model): ...@@ -118,28 +118,28 @@ def test_detokenize_false(model):
assert all(lp.decoded_token is None for lp in logprobs.values()) assert all(lp.decoded_token is None for lp in logprobs.values())
def test_bad_words(model): def test_bad_words(llm):
"""Check that we respect bad words.""" """Check that we respect bad words."""
output = model.generate(PROMPT, SamplingParams(temperature=0)) output = llm.generate(PROMPT, SamplingParams(temperature=0))
split_text = output[0].outputs[0].text.split() split_text = output[0].outputs[0].text.split()
bad_words_1 = " ".join(split_text[:2]) bad_words_1 = " ".join(split_text[:2])
params = SamplingParams(temperature=0, bad_words=[bad_words_1]) params = SamplingParams(temperature=0, bad_words=[bad_words_1])
output = model.generate(PROMPT, params) output = llm.generate(PROMPT, params)
new_text = output[0].outputs[0].text new_text = output[0].outputs[0].text
assert bad_words_1 not in new_text assert bad_words_1 not in new_text
bad_words_2 = new_text.split()[-1] bad_words_2 = new_text.split()[-1]
params = SamplingParams(temperature=0, params = SamplingParams(temperature=0,
bad_words=[bad_words_1, bad_words_2]) bad_words=[bad_words_1, bad_words_2])
output = model.generate(PROMPT, params) output = llm.generate(PROMPT, params)
new_text = output[0].outputs[0].text new_text = output[0].outputs[0].text
assert bad_words_1 not in new_text assert bad_words_1 not in new_text
assert bad_words_2 not in new_text assert bad_words_2 not in new_text
def test_logits_processor(model): def test_logits_processor(llm):
"""Check that we reject logits processor.""" """Check that we reject logits processor."""
# This sample logits processor gives infinite score to the i-th token, # This sample logits processor gives infinite score to the i-th token,
...@@ -150,47 +150,45 @@ def test_logits_processor(model): ...@@ -150,47 +150,45 @@ def test_logits_processor(model):
return logits return logits
with pytest.raises(ValueError): with pytest.raises(ValueError):
_ = model.generate(PROMPT, _ = llm.generate(PROMPT, SamplingParams(logits_processors=[pick_ith]))
SamplingParams(logits_processors=[pick_ith]))
def test_allowed_token_ids(model): def test_allowed_token_ids(llm):
"""Check that we can use allowed_token_ids.""" """Check that we can use allowed_token_ids."""
TOKEN_ID = 10 TOKEN_ID = 10
allowed_token_ids = [TOKEN_ID] allowed_token_ids = [TOKEN_ID]
output = model.generate( output = llm.generate(PROMPT,
PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids)) SamplingParams(allowed_token_ids=allowed_token_ids))
assert output[0].outputs[0].token_ids[-1] == TOKEN_ID assert output[0].outputs[0].token_ids[-1] == TOKEN_ID
# Reject empty allowed_token_ids. # Reject empty allowed_token_ids.
with pytest.raises(ValueError): with pytest.raises(ValueError):
_ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[])) _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[]))
# Reject negative token id. # Reject negative token id.
with pytest.raises(ValueError): with pytest.raises(ValueError):
_ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[-1])) _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[-1]))
# Reject out of vocabulary. # Reject out of vocabulary.
with pytest.raises(ValueError): with pytest.raises(ValueError):
_ = model.generate(PROMPT, _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[10000000]))
SamplingParams(allowed_token_ids=[10000000]))
def test_priority(model): def test_priority(llm):
"""Check that we reject requests with priority.""" """Check that we reject requests with priority."""
# Reject all allowed token ids # Reject all allowed token ids
with pytest.raises(ValueError): with pytest.raises(ValueError):
_ = model.generate(PROMPT, priority=[1]) _ = llm.generate(PROMPT, priority=[1])
def test_seed(model): def test_seed(llm):
"""Check that seed impacts randomness.""" """Check that seed impacts randomness."""
out_1 = model.generate(PROMPT, SamplingParams(seed=42)) out_1 = llm.generate(PROMPT, SamplingParams(seed=42))
out_2 = model.generate(PROMPT, SamplingParams(seed=42)) out_2 = llm.generate(PROMPT, SamplingParams(seed=42))
out_3 = model.generate(PROMPT, SamplingParams(seed=43)) out_3 = llm.generate(PROMPT, SamplingParams(seed=43))
assert out_1[0].outputs[0].text == out_2[0].outputs[0].text assert out_1[0].outputs[0].text == out_2[0].outputs[0].text
assert out_1[0].outputs[0].text != out_3[0].outputs[0].text assert out_1[0].outputs[0].text != out_3[0].outputs[0].text
...@@ -106,9 +106,9 @@ def test_v1_llm_by_default(monkeypatch): ...@@ -106,9 +106,9 @@ def test_v1_llm_by_default(monkeypatch):
m.delenv("VLLM_USE_V1") m.delenv("VLLM_USE_V1")
# Should default to V1 for supported config. # Should default to V1 for supported config.
model = LLM(MODEL, enforce_eager=True, enable_lora=True) llm = LLM(MODEL, enforce_eager=True, enable_lora=True)
print(model.generate("Hello my name is")) print(llm.generate("Hello my name is"))
assert hasattr(model.llm_engine, "engine_core") assert hasattr(llm.llm_engine, "engine_core")
m.delenv("VLLM_USE_V1") m.delenv("VLLM_USE_V1")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment