"vscode:/vscode.git/clone" did not exist on "9b0187003e62bdb7311b23b5b5026ea8e4e207d3"
Unverified commit d9784107, authored by Ning Xie and committed by GitHub
Browse files

[Misc] unify variable for LLM instance (#20996)


Signed-off-by: Andy Xie <andy.xning@gmail.com>
parent e6b90a28
...@@ -14,7 +14,7 @@ For example: ...@@ -14,7 +14,7 @@ For example:
```python ```python
from vllm import LLM from vllm import LLM
model = LLM( llm = LLM(
model="cerebras/Cerebras-GPT-1.3B", model="cerebras/Cerebras-GPT-1.3B",
hf_overrides={"architectures": ["GPT2LMHeadModel"]}, # GPT-2 hf_overrides={"architectures": ["GPT2LMHeadModel"]}, # GPT-2
) )
......
...@@ -302,7 +302,7 @@ To this end, we allow registration of default multimodal LoRAs to handle this au ...@@ -302,7 +302,7 @@ To this end, we allow registration of default multimodal LoRAs to handle this au
return tokenizer.apply_chat_template(chat, tokenize=False) return tokenizer.apply_chat_template(chat, tokenize=False)
model = LLM( llm = LLM(
model=model_id, model=model_id,
enable_lora=True, enable_lora=True,
max_lora_rank=64, max_lora_rank=64,
...@@ -329,7 +329,7 @@ To this end, we allow registration of default multimodal LoRAs to handle this au ...@@ -329,7 +329,7 @@ To this end, we allow registration of default multimodal LoRAs to handle this au
} }
outputs = model.generate( outputs = llm.generate(
inputs, inputs,
sampling_params=SamplingParams( sampling_params=SamplingParams(
temperature=0.2, temperature=0.2,
......
...@@ -86,8 +86,9 @@ Load and run the model in `vllm`: ...@@ -86,8 +86,9 @@ Load and run the model in `vllm`:
```python ```python
from vllm import LLM from vllm import LLM
model = LLM("./Meta-Llama-3-8B-Instruct-FP8-Dynamic")
result = model.generate("Hello my name is") llm = LLM("./Meta-Llama-3-8B-Instruct-FP8-Dynamic")
result = llm.generate("Hello my name is")
print(result[0].outputs[0].text) print(result[0].outputs[0].text)
``` ```
...@@ -125,9 +126,10 @@ In this mode, all Linear modules (except for the final `lm_head`) have their wei ...@@ -125,9 +126,10 @@ In this mode, all Linear modules (except for the final `lm_head`) have their wei
```python ```python
from vllm import LLM from vllm import LLM
model = LLM("facebook/opt-125m", quantization="fp8")
llm = LLM("facebook/opt-125m", quantization="fp8")
# INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB # INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB
result = model.generate("Hello, my name is") result = llm.generate("Hello, my name is")
print(result[0].outputs[0].text) print(result[0].outputs[0].text)
``` ```
......
...@@ -108,7 +108,8 @@ After quantization, you can load and run the model in vLLM: ...@@ -108,7 +108,8 @@ After quantization, you can load and run the model in vLLM:
```python ```python
from vllm import LLM from vllm import LLM
model = LLM("./Meta-Llama-3-8B-Instruct-W4A16-G128")
llm = LLM("./Meta-Llama-3-8B-Instruct-W4A16-G128")
``` ```
To evaluate accuracy, you can use `lm_eval`: To evaluate accuracy, you can use `lm_eval`:
......
...@@ -114,7 +114,8 @@ After quantization, you can load and run the model in vLLM: ...@@ -114,7 +114,8 @@ After quantization, you can load and run the model in vLLM:
```python ```python
from vllm import LLM from vllm import LLM
model = LLM("./Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token")
llm = LLM("./Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token")
``` ```
To evaluate accuracy, you can use `lm_eval`: To evaluate accuracy, you can use `lm_eval`:
......
...@@ -174,11 +174,11 @@ You can change the output dimensions of embedding models that support Matryoshka ...@@ -174,11 +174,11 @@ You can change the output dimensions of embedding models that support Matryoshka
```python ```python
from vllm import LLM, PoolingParams from vllm import LLM, PoolingParams
model = LLM(model="jinaai/jina-embeddings-v3", llm = LLM(model="jinaai/jina-embeddings-v3",
task="embed", task="embed",
trust_remote_code=True) trust_remote_code=True)
outputs = model.embed(["Follow the white rabbit."], outputs = llm.embed(["Follow the white rabbit."],
pooling_params=PoolingParams(dimensions=32)) pooling_params=PoolingParams(dimensions=32))
print(outputs[0].outputs) print(outputs[0].outputs)
``` ```
......
...@@ -28,10 +28,10 @@ def main(args: Namespace): ...@@ -28,10 +28,10 @@ def main(args: Namespace):
# Create an LLM. # Create an LLM.
# You should pass task="classify" for classification models # You should pass task="classify" for classification models
model = LLM(**vars(args)) llm = LLM(**vars(args))
# Generate logits. The output is a list of ClassificationRequestOutputs. # Generate logits. The output is a list of ClassificationRequestOutputs.
outputs = model.classify(prompts) outputs = llm.classify(prompts)
# Print the outputs. # Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60) print("\nGenerated Outputs:\n" + "-" * 60)
......
...@@ -31,10 +31,10 @@ def main(args: Namespace): ...@@ -31,10 +31,10 @@ def main(args: Namespace):
# Create an LLM. # Create an LLM.
# You should pass task="embed" for embedding models # You should pass task="embed" for embedding models
model = LLM(**vars(args)) llm = LLM(**vars(args))
# Generate embedding. The output is a list of EmbeddingRequestOutputs. # Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.embed(prompts) outputs = llm.embed(prompts)
# Print the outputs. # Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60) print("\nGenerated Outputs:\n" + "-" * 60)
......
...@@ -27,10 +27,10 @@ def main(args: Namespace): ...@@ -27,10 +27,10 @@ def main(args: Namespace):
# Create an LLM. # Create an LLM.
# You should pass task="score" for cross-encoder models # You should pass task="score" for cross-encoder models
model = LLM(**vars(args)) llm = LLM(**vars(args))
# Generate scores. The output is a list of ScoringRequestOutputs. # Generate scores. The output is a list of ScoringRequestOutputs.
outputs = model.score(text_1, texts_2) outputs = llm.score(text_1, texts_2)
# Print the outputs. # Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60) print("\nGenerated Outputs:\n" + "-" * 60)
......
...@@ -30,11 +30,11 @@ def main(args: Namespace): ...@@ -30,11 +30,11 @@ def main(args: Namespace):
# Create an LLM. # Create an LLM.
# You should pass task="embed" for embedding models # You should pass task="embed" for embedding models
model = LLM(**vars(args)) llm = LLM(**vars(args))
# Generate embedding. The output is a list of EmbeddingRequestOutputs. # Generate embedding. The output is a list of EmbeddingRequestOutputs.
# Only text matching task is supported for now. See #16120 # Only text matching task is supported for now. See #16120
outputs = model.embed(prompts) outputs = llm.embed(prompts)
# Print the outputs. # Print the outputs.
print("\nGenerated Outputs:") print("\nGenerated Outputs:")
......
...@@ -30,10 +30,10 @@ def main(args: Namespace): ...@@ -30,10 +30,10 @@ def main(args: Namespace):
# Create an LLM. # Create an LLM.
# You should pass task="embed" for embedding models # You should pass task="embed" for embedding models
model = LLM(**vars(args)) llm = LLM(**vars(args))
# Generate embedding. The output is a list of EmbeddingRequestOutputs. # Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.embed(prompts, pooling_params=PoolingParams(dimensions=32)) outputs = llm.embed(prompts, pooling_params=PoolingParams(dimensions=32))
# Print the outputs. # Print the outputs.
print("\nGenerated Outputs:") print("\nGenerated Outputs:")
......
...@@ -25,7 +25,7 @@ def config_buckets(): ...@@ -25,7 +25,7 @@ def config_buckets():
os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
def initialize_model(): def initialize_llm():
"""Create an LLM with speculative decoding.""" """Create an LLM with speculative decoding."""
return LLM( return LLM(
model="openlm-research/open_llama_7b", model="openlm-research/open_llama_7b",
...@@ -43,9 +43,9 @@ def initialize_model(): ...@@ -43,9 +43,9 @@ def initialize_model():
) )
def process_requests(model: LLM, sampling_params: SamplingParams): def process_requests(llm: LLM, sampling_params: SamplingParams):
"""Generate texts from prompts and print them.""" """Generate texts from prompts and print them."""
outputs = model.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
...@@ -53,12 +53,12 @@ def process_requests(model: LLM, sampling_params: SamplingParams): ...@@ -53,12 +53,12 @@ def process_requests(model: LLM, sampling_params: SamplingParams):
def main(): def main():
"""Main function that sets up the model and processes prompts.""" """Main function that sets up the llm and processes prompts."""
config_buckets() config_buckets()
model = initialize_model() llm = initialize_llm()
# Create a sampling params object. # Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, top_k=1) sampling_params = SamplingParams(max_tokens=100, top_k=1)
process_requests(model, sampling_params) process_requests(llm, sampling_params)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -140,7 +140,7 @@ datamodule_config = { ...@@ -140,7 +140,7 @@ datamodule_config = {
class PrithviMAE: class PrithviMAE:
def __init__(self): def __init__(self):
print("Initializing PrithviMAE model") print("Initializing PrithviMAE model")
self.model = LLM( self.llm = LLM(
model=os.path.join(os.path.dirname(__file__), "./model"), model=os.path.join(os.path.dirname(__file__), "./model"),
skip_tokenizer_init=True, skip_tokenizer_init=True,
dtype="float32", dtype="float32",
...@@ -158,7 +158,7 @@ class PrithviMAE: ...@@ -158,7 +158,7 @@ class PrithviMAE:
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = self.model.encode(prompt, use_tqdm=False) outputs = self.llm.encode(prompt, use_tqdm=False)
print("################ Inference done (it took seconds) ##############") print("################ Inference done (it took seconds) ##############")
return outputs[0].outputs.data return outputs[0].outputs.data
......
...@@ -17,13 +17,13 @@ model_name = "Qwen/Qwen3-Reranker-0.6B" ...@@ -17,13 +17,13 @@ model_name = "Qwen/Qwen3-Reranker-0.6B"
# Models converted offline using this method can not only be more efficient # Models converted offline using this method can not only be more efficient
# and support the vllm score API, but also make the init parameters more # and support the vllm score API, but also make the init parameters more
# concise, for example. # concise, for example.
# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score") # llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")
# If you want to load the official original version, the init parameters are # If you want to load the official original version, the init parameters are
# as follows. # as follows.
def get_model() -> LLM: def get_llm() -> LLM:
"""Initializes and returns the LLM model for Qwen3-Reranker.""" """Initializes and returns the LLM model for Qwen3-Reranker."""
return LLM( return LLM(
model=model_name, model=model_name,
...@@ -77,8 +77,8 @@ def main() -> None: ...@@ -77,8 +77,8 @@ def main() -> None:
] ]
documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents] documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]
model = get_model() llm = get_llm()
outputs = model.score(queries, documents) outputs = llm.score(queries, documents)
print("-" * 30) print("-" * 30)
print([output.outputs.score for output in outputs]) print([output.outputs.score for output in outputs])
......
...@@ -236,13 +236,13 @@ def test_failed_model_execution(vllm_runner, monkeypatch) -> None: ...@@ -236,13 +236,13 @@ def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
if isinstance(vllm_model.model.llm_engine, LLMEngineV1): if isinstance(vllm_model.llm.llm_engine, LLMEngineV1):
v1_test_failed_model_execution(vllm_model) v1_test_failed_model_execution(vllm_model)
def v1_test_failed_model_execution(vllm_model): def v1_test_failed_model_execution(vllm_model):
engine = vllm_model.model.llm_engine engine = vllm_model.llm.llm_engine
mocked_execute_model = Mock( mocked_execute_model = Mock(
side_effect=RuntimeError("Mocked Critical Error")) side_effect=RuntimeError("Mocked Critical Error"))
engine.engine_core.engine_core.model_executor.execute_model =\ engine.engine_core.engine_core.model_executor.execute_model =\
......
...@@ -81,7 +81,7 @@ def test_chunked_prefill_recompute( ...@@ -81,7 +81,7 @@ def test_chunked_prefill_recompute(
disable_log_stats=False, disable_log_stats=False,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT) < ARTIFICIAL_PREEMPTION_MAX_CNT)
for i in range(len(example_prompts)): for i in range(len(example_prompts)):
...@@ -118,10 +118,10 @@ def test_preemption( ...@@ -118,10 +118,10 @@ def test_preemption(
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT) < ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = ( total_preemption = (
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption) vllm_model.llm.llm_engine.scheduler[0].num_cumulative_preemption)
check_outputs_equal( check_outputs_equal(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
...@@ -174,12 +174,12 @@ def test_preemption_infeasible( ...@@ -174,12 +174,12 @@ def test_preemption_infeasible(
) as vllm_model: ) as vllm_model:
sampling_params = SamplingParams(max_tokens=max_tokens, sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True) ignore_eos=True)
req_outputs = vllm_model.model.generate( req_outputs = vllm_model.llm.generate(
example_prompts, example_prompts,
sampling_params=sampling_params, sampling_params=sampling_params,
) )
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT) < ARTIFICIAL_PREEMPTION_MAX_CNT)
# Verify the request is ignored and not hang. # Verify the request is ignored and not hang.
......
...@@ -784,7 +784,7 @@ class VllmRunner: ...@@ -784,7 +784,7 @@ class VllmRunner:
enforce_eager: Optional[bool] = False, enforce_eager: Optional[bool] = False,
**kwargs, **kwargs,
) -> None: ) -> None:
self.model = LLM( self.llm = LLM(
model=model_name, model=model_name,
task=task, task=task,
tokenizer=tokenizer_name, tokenizer=tokenizer_name,
...@@ -854,9 +854,9 @@ class VllmRunner: ...@@ -854,9 +854,9 @@ class VllmRunner:
videos=videos, videos=videos,
audios=audios) audios=audios)
req_outputs = self.model.generate(inputs, req_outputs = self.llm.generate(inputs,
sampling_params=sampling_params, sampling_params=sampling_params,
**kwargs) **kwargs)
outputs: list[tuple[list[list[int]], list[str]]] = [] outputs: list[tuple[list[list[int]], list[str]]] = []
for req_output in req_outputs: for req_output in req_outputs:
...@@ -902,9 +902,9 @@ class VllmRunner: ...@@ -902,9 +902,9 @@ class VllmRunner:
videos=videos, videos=videos,
audios=audios) audios=audios)
req_outputs = self.model.generate(inputs, req_outputs = self.llm.generate(inputs,
sampling_params=sampling_params, sampling_params=sampling_params,
**kwargs) **kwargs)
toks_str_logsprobs_prompt_logprobs = ( toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs)) self._final_steps_generate_w_logprobs(req_outputs))
...@@ -924,8 +924,8 @@ class VllmRunner: ...@@ -924,8 +924,8 @@ class VllmRunner:
''' '''
assert sampling_params.logprobs is not None assert sampling_params.logprobs is not None
req_outputs = self.model.generate(encoder_decoder_prompts, req_outputs = self.llm.generate(encoder_decoder_prompts,
sampling_params=sampling_params) sampling_params=sampling_params)
toks_str_logsprobs_prompt_logprobs = ( toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs)) self._final_steps_generate_w_logprobs(req_outputs))
# Omit prompt logprobs if not required by sampling params # Omit prompt logprobs if not required by sampling params
...@@ -1018,7 +1018,7 @@ class VllmRunner: ...@@ -1018,7 +1018,7 @@ class VllmRunner:
videos=videos, videos=videos,
audios=audios) audios=audios)
outputs = self.model.beam_search( outputs = self.llm.beam_search(
inputs, inputs,
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
returned_outputs = [] returned_outputs = []
...@@ -1029,7 +1029,7 @@ class VllmRunner: ...@@ -1029,7 +1029,7 @@ class VllmRunner:
return returned_outputs return returned_outputs
def classify(self, prompts: list[str]) -> list[list[float]]: def classify(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.model.classify(prompts) req_outputs = self.llm.classify(prompts)
return [req_output.outputs.probs for req_output in req_outputs] return [req_output.outputs.probs for req_output in req_outputs]
def embed(self, def embed(self,
...@@ -1044,11 +1044,11 @@ class VllmRunner: ...@@ -1044,11 +1044,11 @@ class VllmRunner:
videos=videos, videos=videos,
audios=audios) audios=audios)
req_outputs = self.model.embed(inputs, *args, **kwargs) req_outputs = self.llm.embed(inputs, *args, **kwargs)
return [req_output.outputs.embedding for req_output in req_outputs] return [req_output.outputs.embedding for req_output in req_outputs]
def encode(self, prompts: list[str]) -> list[list[float]]: def encode(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.model.encode(prompts) req_outputs = self.llm.encode(prompts)
return [req_output.outputs.data for req_output in req_outputs] return [req_output.outputs.data for req_output in req_outputs]
def score( def score(
...@@ -1058,18 +1058,18 @@ class VllmRunner: ...@@ -1058,18 +1058,18 @@ class VllmRunner:
*args, *args,
**kwargs, **kwargs,
) -> list[float]: ) -> list[float]:
req_outputs = self.model.score(text_1, text_2, *args, **kwargs) req_outputs = self.llm.score(text_1, text_2, *args, **kwargs)
return [req_output.outputs.score for req_output in req_outputs] return [req_output.outputs.score for req_output in req_outputs]
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
executor = self.model.llm_engine.model_executor executor = self.llm.llm_engine.model_executor
return executor.apply_model(func) return executor.apply_model(func)
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
del self.model del self.llm
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
......
...@@ -37,7 +37,7 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, ...@@ -37,7 +37,7 @@ def test_num_computed_tokens_update(num_scheduler_steps: int,
num_scheduler_steps=num_scheduler_steps, num_scheduler_steps=num_scheduler_steps,
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
enforce_eager=enforce_eager) enforce_eager=enforce_eager)
engine: LLMEngine = runner.model.llm_engine engine: LLMEngine = runner.llm.llm_engine
# In multi-step + chunked-prefill there is no separate single prompt step. # In multi-step + chunked-prefill there is no separate single prompt step.
# What is scheduled will run for num_scheduler_steps always. # What is scheduled will run for num_scheduler_steps always.
......
...@@ -28,7 +28,7 @@ def vllm_model(vllm_runner): ...@@ -28,7 +28,7 @@ def vllm_model(vllm_runner):
def test_stop_reason(vllm_model, example_prompts): def test_stop_reason(vllm_model, example_prompts):
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL) tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR) stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
llm = vllm_model.model llm = vllm_model.llm
# test stop token # test stop token
outputs = llm.generate(example_prompts, outputs = llm.generate(example_prompts,
......
...@@ -101,42 +101,42 @@ def _stop_token_id(llm): ...@@ -101,42 +101,42 @@ def _stop_token_id(llm):
def test_stop_strings(): def test_stop_strings():
# If V0, must set enforce_eager=False since we use # If V0, must set enforce_eager=False since we use
# async output processing below. # async output processing below.
vllm_model = LLM(MODEL, enforce_eager=envs.VLLM_USE_V1) llm = LLM(MODEL, enforce_eager=envs.VLLM_USE_V1)
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
_stop_basic(vllm_model) _stop_basic(llm)
else: else:
_set_async_mode(vllm_model, True) _set_async_mode(llm, True)
_stop_basic(vllm_model) _stop_basic(llm)
_set_async_mode(vllm_model, False) _set_async_mode(llm, False)
_stop_basic(vllm_model) _stop_basic(llm)
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
_stop_multi_tokens(vllm_model) _stop_multi_tokens(llm)
else: else:
_set_async_mode(vllm_model, True) _set_async_mode(llm, True)
_stop_multi_tokens(vllm_model) _stop_multi_tokens(llm)
_set_async_mode(vllm_model, False) _set_async_mode(llm, False)
_stop_multi_tokens(vllm_model) _stop_multi_tokens(llm)
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
_stop_partial_token(vllm_model) _stop_partial_token(llm)
else: else:
_set_async_mode(vllm_model, True) _set_async_mode(llm, True)
_stop_partial_token(vllm_model) _stop_partial_token(llm)
_set_async_mode(vllm_model, False) _set_async_mode(llm, False)
_stop_partial_token(vllm_model) _stop_partial_token(llm)
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
# FIXME: this does not respect include_in_output=False # FIXME: this does not respect include_in_output=False
# _stop_token_id(vllm_model) # _stop_token_id(llm)
pass pass
else: else:
_set_async_mode(vllm_model, True) _set_async_mode(llm, True)
_stop_token_id(vllm_model) _stop_token_id(llm)
_set_async_mode(vllm_model, False) _set_async_mode(llm, False)
_stop_token_id(vllm_model) _stop_token_id(llm)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment