Unverified Commit 28e1299e authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

rename PromptInputs and inputs with backward compatibility (#8760)

parent 0c4d2ad5
......@@ -9,8 +9,8 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
SingletonPromptInputs)
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
SingletonPrompt)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
if TYPE_CHECKING:
......@@ -206,7 +206,7 @@ class InputPreprocessor:
def _extract_prompt_components(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
......@@ -216,7 +216,7 @@ class InputPreprocessor:
Arguments:
* request_id
* inputs: single encoder or decoder input prompt
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
......@@ -226,24 +226,24 @@ class InputPreprocessor:
* multi_modal_data
'''
parsed = parse_singleton_prompt(inputs)
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "str":
prompt = parsed["content"]
prompt_text = parsed["content"]
prompt_token_ids = self._tokenize_prompt(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif parsed["type"] == "tokens":
prompt = None
prompt_text = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
elif parsed["type"] == "text":
prompt = parsed["content"]["prompt"]
prompt_text = parsed["content"]["prompt"]
prompt_token_ids = self._tokenize_prompt(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
......@@ -251,33 +251,33 @@ class InputPreprocessor:
else:
assert_never(parsed)
return prompt, prompt_token_ids, multi_modal_data
return prompt_text, prompt_token_ids, multi_modal_data
async def _extract_prompt_components_async(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(inputs)
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "str":
prompt = parsed["content"]
prompt_text = parsed["content"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif parsed["type"] == "tokens":
prompt = None
prompt_text = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
elif parsed["type"] == "text":
prompt = parsed["content"]["prompt"]
prompt_text = parsed["content"]["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
......@@ -285,7 +285,7 @@ class InputPreprocessor:
else:
assert_never(parsed)
return prompt, prompt_token_ids, multi_modal_data
return prompt_text, prompt_token_ids, multi_modal_data
def _build_enc_dec_llm_inputs(
self,
......@@ -311,7 +311,7 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
) -> EncoderDecoderLLMInputs:
'''
......@@ -339,7 +339,7 @@ class InputPreprocessor:
Arguments:
* inputs: an input prompt
* prompt: an input prompt
* request_id
Returns:
......@@ -350,13 +350,13 @@ class InputPreprocessor:
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
encoder_comps = self._extract_prompt_components(
inputs["encoder_prompt"],
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_comps = None, None, None
else:
decoder_comps = self._extract_prompt_components(
......@@ -365,7 +365,7 @@ class InputPreprocessor:
)
else:
encoder_comps = self._extract_prompt_components(
inputs,
prompt,
request_id=request_id,
)
......@@ -375,20 +375,20 @@ class InputPreprocessor:
async def _process_encoder_decoder_prompt_async(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
) -> EncoderDecoderLLMInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
encoder_task = self._extract_prompt_components_async(
inputs["encoder_prompt"],
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
if (decoder_input := prompt["decoder_prompt"]) is None:
encoder_comps = await encoder_task
decoder_comps = None, None, None
else:
......@@ -401,7 +401,7 @@ class InputPreprocessor:
encoder_task, decoder_task)
else:
encoder_comps = await self._extract_prompt_components_async(
inputs,
prompt,
request_id=request_id,
)
......@@ -425,7 +425,7 @@ class InputPreprocessor:
def _process_decoder_only_prompt(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
......@@ -436,7 +436,7 @@ class InputPreprocessor:
Arguments:
* inputs: input prompt
* prompt: input prompt
* request_id
* lora_request
* prompt_adapter_request
......@@ -447,7 +447,7 @@ class InputPreprocessor:
'''
prompt_comps = self._extract_prompt_components(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
)
......@@ -459,14 +459,14 @@ class InputPreprocessor:
async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
)
......@@ -478,7 +478,7 @@ class InputPreprocessor:
def preprocess(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
......@@ -488,17 +488,17 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return self._process_encoder_decoder_prompt(
inputs,
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return self._process_decoder_only_prompt(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
......@@ -506,7 +506,7 @@ class InputPreprocessor:
async def preprocess_async(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
......@@ -516,17 +516,17 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(
inputs,
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return await self._process_decoder_only_prompt_async(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment