Unverified Commit 3b19e39d authored by nunjunj's avatar nunjunj Committed by GitHub
Browse files
parent 4cd7d47f
...@@ -147,6 +147,7 @@ steps: ...@@ -147,6 +147,7 @@ steps:
- pip install awscli tensorizer # for llava example and tensorizer test - pip install awscli tensorizer # for llava example and tensorizer test
- python3 offline_inference.py - python3 offline_inference.py
- python3 cpu_offload.py - python3 cpu_offload.py
- python3 offline_inference_chat.py
- python3 offline_inference_with_prefix.py - python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py - python3 llm_engine_example.py
- python3 offline_inference_vision_language.py - python3 offline_inference_vision_language.py
......
from vllm import LLM, SamplingParams
# Engine used by every chat call in this example script.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
)

# Sampling configuration passed to the chat call.
sampling_params = SamplingParams(
    temperature=0.5,
)
def print_outputs(outputs):
    """Print the prompt and first completion of every output.

    Each item in ``outputs`` is read through its ``prompt`` attribute and
    the ``text`` of the first entry in its ``outputs`` list. A line of 80
    dashes is printed once all items have been shown.
    """
    for request_output in outputs:
        print(f"Prompt: {request_output.prompt!r}, "
              f"Generated text: {request_output.outputs[0].text!r}")
    print("-" * 80)
print("=" * 80)

# Demonstrates how a multi-turn conversation is handed to the chat method:
# every turn is a dict carrying a "role" and its "content".
conversation = [
    {"role": role, "content": content}
    for role, content in (
        ("system", "You are a helpful assistant"),
        ("user", "Hello"),
        ("assistant", "Hello! How can I assist you today?"),
        ("user", "Write an essay about the importance of higher education."),
    )
]

outputs = llm.chat(
    conversation,
    sampling_params=sampling_params,
    use_tqdm=False,
)
print_outputs(outputs)
# A chat template can optionally be supplied.
# If it is not, the model's default chat template is used.

# with open('template_falcon_180b.jinja', "r") as f:
#     chat_template = f.read()

# outputs = llm.chat(
#     conversation,
#     sampling_params=sampling_params,
#     use_tqdm=False,
#     chat_template=chat_template,
# )
...@@ -140,3 +140,22 @@ def test_multiple_sampling_params(llm: LLM): ...@@ -140,3 +140,22 @@ def test_multiple_sampling_params(llm: LLM):
# sampling_params is None, default params should be applied # sampling_params is None, default params should be applied
outputs = llm.generate(PROMPTS, sampling_params=None) outputs = llm.generate(PROMPTS, sampling_params=None)
assert len(PROMPTS) == len(outputs) assert len(PROMPTS) == len(outputs)
def test_chat():
    """Smoke test: one conversation passed to ``chat`` yields one output."""
    question = "Explain the concept of entropy."
    chat_messages = [
        {"role": role, "content": content}
        for role, content in (
            ("system", "You are a helpful assistant"),
            ("user", question),
        )
    ]

    engine = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
    responses = engine.chat(chat_messages)

    assert len(responses) == 1
...@@ -6,6 +6,9 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast ...@@ -6,6 +6,9 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_chat_template,
parse_chat_messages)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -87,7 +90,7 @@ class LLM: ...@@ -87,7 +90,7 @@ class LLM:
disable_custom_all_reduce: See ParallelConfig disable_custom_all_reduce: See ParallelConfig
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`) :ref:`engine_args`)
Note: Note:
This class is intended to be used for offline inference. For online This class is intended to be used for offline inference. For online
serving, use the :class:`~vllm.AsyncLLMEngine` class instead. serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
...@@ -138,8 +141,12 @@ class LLM: ...@@ -138,8 +141,12 @@ class LLM:
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True kwargs["disable_log_stats"] = True
removed_vision_keys = ("image_token_id", "image_feature_size", removed_vision_keys = (
"image_input_shape", "image_input_type") "image_token_id",
"image_feature_size",
"image_input_shape",
"image_input_type",
)
if any(k in kwargs for k in removed_vision_keys): if any(k in kwargs for k in removed_vision_keys):
raise TypeError( raise TypeError(
"There is no need to pass vision-related arguments anymore.") "There is no need to pass vision-related arguments anymore.")
...@@ -259,11 +266,12 @@ class LLM: ...@@ -259,11 +266,12 @@ class LLM:
) -> List[RequestOutput]: ) -> List[RequestOutput]:
... ...
@deprecate_kwargs("prompts", @deprecate_kwargs(
"prompt_token_ids", "prompts",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY, "prompt_token_ids",
additional_message="Please use the 'inputs' parameter " is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
"instead.") additional_message="Please use the 'inputs' parameter instead.",
)
def generate( def generate(
self, self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
...@@ -286,17 +294,17 @@ class LLM: ...@@ -286,17 +294,17 @@ class LLM:
Args: Args:
inputs: A list of inputs to generate completions for. inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters. None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt. When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt. prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for prompt_adapter_request: Prompt Adapter request to use for
generation, if any. generation, if any.
Returns: Returns:
A list of `RequestOutput` objects containing the A list of ``RequestOutput`` objects containing the
generated completions in the same order as the input prompts. generated completions in the same order as the input prompts.
Note: Note:
...@@ -339,6 +347,62 @@ class LLM: ...@@ -339,6 +347,62 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput) return LLMEngine.validate_outputs(outputs, RequestOutput)
def chat(
    self,
    messages: List[ChatCompletionMessageParam],
    sampling_params: Optional[Union[SamplingParams,
                                    List[SamplingParams]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[LoRARequest] = None,
    chat_template: Optional[str] = None,
    add_generation_template: bool = True,
) -> List[RequestOutput]:
    """
    Generates a response for a chat conversation.

    Converts the conversation to a prompt using the tokenizer's chat
    template and calls the :meth:`generate` method to generate the
    response.

    Args:
        messages: The conversation to respond to, given as a list of
            messages. Each message is a dictionary with 'role' and
            'content' keys.
        sampling_params: The sampling parameters for text generation.
            If None, we use the default sampling parameters. The value
            is passed through unchanged to :meth:`generate`.
        use_tqdm: Whether to use tqdm to display the progress bar.
        lora_request: LoRA request to use for generation, if any.
        chat_template: The template to use for structuring the chat.
            If not provided, the model's default chat template will be
            used.
        add_generation_template: If True, asks the chat template to
            append the generation prompt (the opening of the assistant
            turn) after the last message.

    Returns:
        A list of ``RequestOutput`` objects containing the generated
        response, as returned by :meth:`generate`.
    """
    tokenizer = self.get_tokenizer()
    model_config = self.llm_engine.get_model_config()

    # Only the parsed conversation is needed here; the second value
    # returned by the parser is discarded.
    conversation, _ = parse_chat_messages(messages, model_config,
                                          tokenizer)

    # The Hugging Face chat-template API names this option
    # ``add_generation_prompt``. Passing it under any other name leaves
    # the real option at its default, so the generation prompt would
    # not be appended.
    # NOTE(review): assumes ``apply_chat_template`` forwards extra
    # keyword arguments to the tokenizer's ``apply_chat_template`` --
    # confirm against vllm.entrypoints.chat_utils.
    prompt = apply_chat_template(
        tokenizer,
        conversation,
        chat_template=chat_template,
        add_generation_prompt=add_generation_template,
    )

    return self.generate(
        prompt,
        sampling_params,
        use_tqdm=use_tqdm,
        lora_request=lora_request,
    )
@overload # LEGACY: single (prompt + optional token ids) @overload # LEGACY: single (prompt + optional token ids)
def encode( def encode(
self, self,
...@@ -413,11 +477,12 @@ class LLM: ...@@ -413,11 +477,12 @@ class LLM:
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
... ...
@deprecate_kwargs("prompts", @deprecate_kwargs(
"prompt_token_ids", "prompts",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY, "prompt_token_ids",
additional_message="Please use the 'inputs' parameter " is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
"instead.") additional_message="Please use the 'inputs' parameter instead.",
)
def encode( def encode(
self, self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
...@@ -443,7 +508,7 @@ class LLM: ...@@ -443,7 +508,7 @@ class LLM:
use the default pooling parameters. use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for prompt_adapter_request: Prompt Adapter request to use for
generation, if any. generation, if any.
Returns: Returns:
...@@ -563,15 +628,15 @@ class LLM: ...@@ -563,15 +628,15 @@ class LLM:
params[i] if isinstance(params, Sequence) else params, params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance( lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request, lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request,
)
def _add_request( def _add_request(
self, self,
inputs: PromptInputs, inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest], lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_engine.add_request( self.llm_engine.add_request(
...@@ -579,7 +644,8 @@ class LLM: ...@@ -579,7 +644,8 @@ class LLM:
inputs, inputs,
params, params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request,
)
def _add_guided_processor( def _add_guided_processor(
self, self,
...@@ -628,8 +694,8 @@ class LLM: ...@@ -628,8 +694,8 @@ class LLM:
in_spd = total_in_toks / pbar.format_dict["elapsed"] in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum( total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs) len(stp.token_ids) for stp in output.outputs)
out_spd = total_out_toks / pbar.format_dict[ out_spd = (total_out_toks /
"elapsed"] pbar.format_dict["elapsed"])
pbar.postfix = ( pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, " f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s") f"output: {out_spd:.2f} toks/s")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment