Unverified commit da6ea29f, authored by Nick Hill, committed by GitHub
Browse files

[V1] Avoid redundant input processing in n>1 case (#14985)


Signed-off-by: Nick Hill <nhill@redhat.com>
parent 7297941b
...@@ -24,12 +24,10 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): ...@@ -24,12 +24,10 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
) )
lora_request = LoRARequest("1", 1, sql_lora_files) lora_request = LoRARequest("1", 1, sql_lora_files)
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
request_id="request_id", prompt="prompt", lora_request=lora_request) prompt="prompt", lora_request=lora_request)
assert reference_tokenizer.encode( assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async( "prompt") == await tokenizer_group.encode_async(
request_id="request_id", prompt="prompt", lora_request=lora_request)
prompt="prompt",
lora_request=lora_request)
assert isinstance(tokenizer_group.get_lora_tokenizer(None), assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase) PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer( assert tokenizer_group.get_lora_tokenizer(
......
...@@ -41,10 +41,10 @@ async def test_tokenizer_group(tokenizer_group_type): ...@@ -41,10 +41,10 @@ async def test_tokenizer_group(tokenizer_group_type):
max_input_length=None, max_input_length=None,
) )
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
request_id="request_id", prompt="prompt", lora_request=None) prompt="prompt", lora_request=None)
assert reference_tokenizer.encode( assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async( "prompt") == await tokenizer_group.encode_async(prompt="prompt",
request_id="request_id", prompt="prompt", lora_request=None) lora_request=None)
assert isinstance(tokenizer_group.get_lora_tokenizer(None), assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase) PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer( assert tokenizer_group.get_lora_tokenizer(
...@@ -69,8 +69,7 @@ async def test_tokenizer_group_pool(tokenizer_group_type): ...@@ -69,8 +69,7 @@ async def test_tokenizer_group_pool(tokenizer_group_type):
# and check that all requests are processed correctly. # and check that all requests are processed correctly.
num_requests = tokenizer_group_pool.pool_size * 5 num_requests = tokenizer_group_pool.pool_size * 5
requests = [ requests = [
tokenizer_group_pool.encode_async(request_id=str(i), tokenizer_group_pool.encode_async(prompt=f"prompt {i}",
prompt=f"prompt {i}",
lora_request=None) lora_request=None)
for i in range(num_requests) for i in range(num_requests)
] ]
...@@ -161,12 +160,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type): ...@@ -161,12 +160,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
fail_at[0] = 1000 fail_at[0] = 1000
# We should recover successfully. # We should recover successfully.
await tokenizer_group_pool.encode_async(request_id="1", await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
prompt="prompt", await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
lora_request=None)
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt",
lora_request=None)
# Check that we have a new actor # Check that we have a new actor
assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors) assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
...@@ -184,8 +179,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type): ...@@ -184,8 +179,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
# We should fail after re-initialization. # We should fail after re-initialization.
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await tokenizer_group_pool.encode_async(request_id="1", await tokenizer_group_pool.encode_async(prompt="prompt",
prompt="prompt",
lora_request=None) lora_request=None)
# check_health should raise the same thing # check_health should raise the same thing
...@@ -206,11 +200,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type): ...@@ -206,11 +200,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
# Prompt too long error # Prompt too long error
with pytest.raises(ValueError): with pytest.raises(ValueError):
await tokenizer_group_pool.encode_async(request_id="1", await tokenizer_group_pool.encode_async(prompt="prompt" * 100,
prompt="prompt" * 100,
lora_request=None) lora_request=None)
await tokenizer_group_pool.encode_async(request_id="1", await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
prompt="prompt",
lora_request=None)
# Actors should stay the same. # Actors should stay the same.
assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
...@@ -492,7 +492,6 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -492,7 +492,6 @@ class _AsyncLLMEngine(LLMEngine):
preprocessed_inputs = await self.input_preprocessor.preprocess_async( preprocessed_inputs = await self.input_preprocessor.preprocess_async(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
......
...@@ -783,7 +783,6 @@ class LLMEngine: ...@@ -783,7 +783,6 @@ class LLMEngine:
preprocessed_inputs = self.input_preprocessor.preprocess( preprocessed_inputs = self.input_preprocessor.preprocess(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
......
...@@ -81,10 +81,7 @@ class EngineClient(ABC): ...@@ -81,10 +81,7 @@ class EngineClient(ABC):
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError raise NotImplementedError
else: else:
processed_inputs = preprocessor._prompt_to_llm_inputs( processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
prompt,
request_id=request_id,
)
prompt_token_ids = processed_inputs["prompt_token_ids"] prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt") prompt_text = processed_inputs.get("prompt")
......
...@@ -182,7 +182,6 @@ class InputPreprocessor: ...@@ -182,7 +182,6 @@ class InputPreprocessor:
def _tokenize_prompt( def _tokenize_prompt(
self, self,
prompt: str, prompt: str,
request_id: str,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
) -> list[int]: ) -> list[int]:
""" """
...@@ -202,15 +201,13 @@ class InputPreprocessor: ...@@ -202,15 +201,13 @@ class InputPreprocessor:
"do_lower_case", False)): "do_lower_case", False)):
prompt = prompt.lower() prompt = prompt.lower()
return tokenizer.encode(request_id=request_id, return tokenizer.encode(prompt=prompt,
prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens) add_special_tokens=add_special_tokens)
async def _tokenize_prompt_async( async def _tokenize_prompt_async(
self, self,
prompt: str, prompt: str,
request_id: str,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
) -> list[int]: ) -> list[int]:
"""Async version of :meth:`_tokenize_prompt`.""" """Async version of :meth:`_tokenize_prompt`."""
...@@ -222,7 +219,6 @@ class InputPreprocessor: ...@@ -222,7 +219,6 @@ class InputPreprocessor:
# appending an EOS token to the prompt which disrupts generation. # appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False add_special_tokens = False
return await tokenizer.encode_async( return await tokenizer.encode_async(
request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens) add_special_tokens=add_special_tokens)
...@@ -309,7 +305,6 @@ class InputPreprocessor: ...@@ -309,7 +305,6 @@ class InputPreprocessor:
def _prompt_to_llm_inputs( def _prompt_to_llm_inputs(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> SingletonInputs: ) -> SingletonInputs:
...@@ -318,7 +313,6 @@ class InputPreprocessor: ...@@ -318,7 +313,6 @@ class InputPreprocessor:
Arguments: Arguments:
* request_id
* prompt: single encoder or decoder input prompt * prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts * lora_request: this is only valid for decoder prompts
* return_mm_hashes: whether to return multimodal hashes * return_mm_hashes: whether to return multimodal hashes
...@@ -333,7 +327,6 @@ class InputPreprocessor: ...@@ -333,7 +327,6 @@ class InputPreprocessor:
prompt_text = parsed["content"] prompt_text = parsed["content"]
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
prompt_text, prompt_text,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -384,7 +377,6 @@ class InputPreprocessor: ...@@ -384,7 +377,6 @@ class InputPreprocessor:
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
prompt_text, prompt_text,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -400,7 +392,6 @@ class InputPreprocessor: ...@@ -400,7 +392,6 @@ class InputPreprocessor:
async def _prompt_to_llm_inputs_async( async def _prompt_to_llm_inputs_async(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> SingletonInputs: ) -> SingletonInputs:
...@@ -411,7 +402,6 @@ class InputPreprocessor: ...@@ -411,7 +402,6 @@ class InputPreprocessor:
prompt_text = parsed["content"] prompt_text = parsed["content"]
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
prompt_text, prompt_text,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -460,7 +450,6 @@ class InputPreprocessor: ...@@ -460,7 +450,6 @@ class InputPreprocessor:
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
prompt_text, prompt_text,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -560,7 +549,6 @@ class InputPreprocessor: ...@@ -560,7 +549,6 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt( def _process_encoder_decoder_prompt(
self, self,
prompt: PromptType, prompt: PromptType,
request_id: str,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
For encoder/decoder models only: For encoder/decoder models only:
...@@ -587,7 +575,6 @@ class InputPreprocessor: ...@@ -587,7 +575,6 @@ class InputPreprocessor:
Arguments: Arguments:
* prompt: an input prompt * prompt: an input prompt
* request_id
Returns: Returns:
...@@ -598,16 +585,11 @@ class InputPreprocessor: ...@@ -598,16 +585,11 @@ class InputPreprocessor:
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
encoder_inputs = self._prompt_to_llm_inputs( encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"], prompt["encoder_prompt"])
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None decoder_inputs = None
else: else:
decoder_inputs = self._prompt_to_llm_inputs( decoder_inputs = self._prompt_to_llm_inputs(decoder_input)
decoder_input,
request_id=request_id,
)
# For multimodal model, override decoder prompt from processor # For multimodal model, override decoder prompt from processor
# with explicit decoder prompt. # with explicit decoder prompt.
if self.model_config.is_multimodal_model and ( if self.model_config.is_multimodal_model and (
...@@ -616,10 +598,7 @@ class InputPreprocessor: ...@@ -616,10 +598,7 @@ class InputPreprocessor:
self._separate_enc_dec_inputs_from_mm_processor_outputs( self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs)) encoder_inputs, decoder_inputs))
else: else:
inputs = self._prompt_to_llm_inputs( inputs = self._prompt_to_llm_inputs(prompt)
prompt,
request_id=request_id,
)
if self.model_config.is_multimodal_model and ( if self.model_config.is_multimodal_model and (
self._can_process_multimodal()): self._can_process_multimodal()):
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -636,7 +615,6 @@ class InputPreprocessor: ...@@ -636,7 +615,6 @@ class InputPreprocessor:
async def _process_encoder_decoder_prompt_async( async def _process_encoder_decoder_prompt_async(
self, self,
prompt: PromptType, prompt: PromptType,
request_id: str,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`.""" """Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_inputs: SingletonInputs encoder_inputs: SingletonInputs
...@@ -644,18 +622,13 @@ class InputPreprocessor: ...@@ -644,18 +622,13 @@ class InputPreprocessor:
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
encoder_task = self._prompt_to_llm_inputs_async( encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"], prompt["encoder_prompt"])
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
encoder_inputs = await encoder_task encoder_inputs = await encoder_task
decoder_inputs = None decoder_inputs = None
else: else:
decoder_task = self._prompt_to_llm_inputs_async( decoder_task = self._prompt_to_llm_inputs_async(decoder_input)
decoder_input,
request_id=request_id,
)
encoder_inputs, decoder_inputs = await asyncio.gather( encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task) encoder_task, decoder_task)
...@@ -668,10 +641,7 @@ class InputPreprocessor: ...@@ -668,10 +641,7 @@ class InputPreprocessor:
self._separate_enc_dec_inputs_from_mm_processor_outputs( self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs)) encoder_inputs, decoder_inputs))
else: else:
inputs = await self._prompt_to_llm_inputs_async( inputs = await self._prompt_to_llm_inputs_async(prompt)
prompt,
request_id=request_id,
)
if self.model_config.is_multimodal_model and ( if self.model_config.is_multimodal_model and (
self._can_process_multimodal()): self._can_process_multimodal()):
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -704,7 +674,6 @@ class InputPreprocessor: ...@@ -704,7 +674,6 @@ class InputPreprocessor:
def _process_decoder_only_prompt( def _process_decoder_only_prompt(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
...@@ -716,7 +685,6 @@ class InputPreprocessor: ...@@ -716,7 +685,6 @@ class InputPreprocessor:
Arguments: Arguments:
* prompt: input prompt * prompt: input prompt
* request_id
* lora_request * lora_request
* prompt_adapter_request * prompt_adapter_request
* return_mm_hashes * return_mm_hashes
...@@ -728,7 +696,6 @@ class InputPreprocessor: ...@@ -728,7 +696,6 @@ class InputPreprocessor:
prompt_comps = self._prompt_to_llm_inputs( prompt_comps = self._prompt_to_llm_inputs(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
...@@ -741,7 +708,6 @@ class InputPreprocessor: ...@@ -741,7 +708,6 @@ class InputPreprocessor:
async def _process_decoder_only_prompt_async( async def _process_decoder_only_prompt_async(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
...@@ -749,7 +715,6 @@ class InputPreprocessor: ...@@ -749,7 +715,6 @@ class InputPreprocessor:
"""Async version of :meth:`_process_decoder_only_prompt`.""" """Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._prompt_to_llm_inputs_async( prompt_comps = await self._prompt_to_llm_inputs_async(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
...@@ -762,7 +727,6 @@ class InputPreprocessor: ...@@ -762,7 +727,6 @@ class InputPreprocessor:
def preprocess( def preprocess(
self, self,
prompt: PromptType, prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
...@@ -774,10 +738,7 @@ class InputPreprocessor: ...@@ -774,10 +738,7 @@ class InputPreprocessor:
"returned until they are supported on vLLM V1.") "returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder # input prompts to encoder & decoder
return self._process_encoder_decoder_prompt( return self._process_encoder_decoder_prompt(prompt)
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt " raise ValueError("Cannot pass encoder-decoder prompt "
...@@ -786,7 +747,6 @@ class InputPreprocessor: ...@@ -786,7 +747,6 @@ class InputPreprocessor:
# Decoder-only operation # Decoder-only operation
return self._process_decoder_only_prompt( return self._process_decoder_only_prompt(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
...@@ -795,7 +755,6 @@ class InputPreprocessor: ...@@ -795,7 +755,6 @@ class InputPreprocessor:
async def preprocess_async( async def preprocess_async(
self, self,
prompt: PromptType, prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
...@@ -807,10 +766,7 @@ class InputPreprocessor: ...@@ -807,10 +766,7 @@ class InputPreprocessor:
"returned until they are supported on vLLM V1.") "returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder # input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async( return await self._process_encoder_decoder_prompt_async(prompt)
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt " raise ValueError("Cannot pass encoder-decoder prompt "
...@@ -819,7 +775,6 @@ class InputPreprocessor: ...@@ -819,7 +775,6 @@ class InputPreprocessor:
# Decoder-only operation # Decoder-only operation
return await self._process_decoder_only_prompt_async( return await self._process_decoder_only_prompt_async(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
......
...@@ -33,7 +33,6 @@ class BaseTokenizerGroup(ABC): ...@@ -33,7 +33,6 @@ class BaseTokenizerGroup(ABC):
@abstractmethod @abstractmethod
def encode(self, def encode(self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.""" """Encode a prompt using the tokenizer group."""
...@@ -43,7 +42,6 @@ class BaseTokenizerGroup(ABC): ...@@ -43,7 +42,6 @@ class BaseTokenizerGroup(ABC):
async def encode_async( async def encode_async(
self, self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.""" """Encode a prompt using the tokenizer group."""
......
...@@ -113,7 +113,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -113,7 +113,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
def encode(self, def encode(self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group. """Encode a prompt using the tokenizer group.
...@@ -133,8 +132,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -133,8 +132,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
original_actor = actor original_actor = actor
try: try:
ret = ray.get( ret = ray.get(
actor.encode.remote(request_id=request_id, actor.encode.remote(prompt=prompt,
prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens)) add_special_tokens=add_special_tokens))
except ActorDiedError as e: except ActorDiedError as e:
...@@ -145,8 +143,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -145,8 +143,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
actor = self._init_actor() actor = self._init_actor()
try: try:
ret = ray.get( ret = ray.get(
actor.encode.remote(request_id=request_id, actor.encode.remote(prompt=prompt,
prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens)) add_special_tokens=add_special_tokens))
except ActorDiedError as e: except ActorDiedError as e:
...@@ -164,7 +161,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -164,7 +161,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
async def encode_async( async def encode_async(
self, self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group. """Encode a prompt using the tokenizer group.
...@@ -184,7 +180,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -184,7 +180,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
original_actor = actor original_actor = actor
try: try:
ret = await actor.encode.remote( ret = await actor.encode.remote(
request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens) add_special_tokens=add_special_tokens)
...@@ -196,7 +191,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): ...@@ -196,7 +191,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
actor = self._init_actor() actor = self._init_actor()
try: try:
ret = await actor.encode.remote( ret = await actor.encode.remote(
request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request, lora_request=lora_request,
add_special_tokens=add_special_tokens) add_special_tokens=add_special_tokens)
......
...@@ -56,7 +56,6 @@ class TokenizerGroup(BaseTokenizerGroup): ...@@ -56,7 +56,6 @@ class TokenizerGroup(BaseTokenizerGroup):
def encode(self, def encode(self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request) tokenizer = self.get_lora_tokenizer(lora_request)
...@@ -69,7 +68,6 @@ class TokenizerGroup(BaseTokenizerGroup): ...@@ -69,7 +68,6 @@ class TokenizerGroup(BaseTokenizerGroup):
async def encode_async( async def encode_async(
self, self,
prompt: str, prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request) tokenizer = await self.get_lora_tokenizer_async(lora_request)
......
...@@ -4,6 +4,7 @@ import asyncio ...@@ -4,6 +4,7 @@ import asyncio
import logging import logging
import os import os
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Optional, Union from typing import Optional, Union
import numpy as np import numpy as np
...@@ -25,6 +26,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer ...@@ -25,6 +26,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv, kill_process_tree from vllm.utils import Device, cdiv, kill_process_tree
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
...@@ -177,33 +179,44 @@ class AsyncLLM(EngineClient): ...@@ -177,33 +179,44 @@ class AsyncLLM(EngineClient):
) -> asyncio.Queue[RequestOutput]: ) -> asyncio.Queue[RequestOutput]:
"""Add new request to the AsyncLLM.""" """Add new request to the AsyncLLM."""
# 1) Create a new output queue for the request. # Create a new output queue for the request.
queue: asyncio.Queue[RequestOutput] = asyncio.Queue() queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
# 2) Fan out child requests (for n>1) # Convert Input --> Request.
parent_req = ParentRequest.from_params(request_id, params) request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)
n = params.n if isinstance(params, SamplingParams) else 1 n = params.n if isinstance(params, SamplingParams) else 1
for idx in range(n):
if parent_req is not None:
request_id, params = parent_req.get_child_info(idx)
# 3) Convert Input --> Request. if n == 1:
request = self.processor.process_inputs(request_id, prompt, params, await self._add_request(request, None, 0, queue)
arrival_time, lora_request, return queue
trace_headers,
prompt_adapter_request,
priority)
# 4) Add the request to OutputProcessor (this process). # Fan out child requests (for n>1).
self.output_processor.add_request(request, parent_req, idx, queue) parent_request = ParentRequest(request_id, params)
for idx in range(n):
request_id, params = parent_request.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
await self._add_request(child_request, parent_request, idx, queue)
return queue
# 5) Add the EngineCoreRequest to EngineCore (separate process). async def _add_request(self, request: EngineCoreRequest,
await self.engine_core.add_request_async(request) parent_req: Optional[ParentRequest], index: int,
queue: asyncio.Queue[RequestOutput]):
if self.log_requests: # Add the request to OutputProcessor (this process).
logger.info("Added request %s.", request_id) self.output_processor.add_request(request, parent_req, index, queue)
return queue # Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(request)
if self.log_requests:
logger.info("Added request %s.", request.request_id)
# TODO: we should support multiple prompts in one call, as you # TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion # can do with LLM.generate. So that for multi-prompt completion
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Mapping from collections.abc import Mapping
from copy import copy
from typing import Optional, Union from typing import Optional, Union
from typing_extensions import TypeVar from typing_extensions import TypeVar
...@@ -179,25 +180,34 @@ class LLMEngine: ...@@ -179,25 +180,34 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None: ) -> None:
# 1) Fan out child requests (for n>1) # Process raw inputs into the request.
parent_req = ParentRequest.from_params(request_id, params) request = self.processor.process_inputs(request_id, prompt, params,
n = params.n if isinstance(params, SamplingParams) else 1 arrival_time, lora_request,
for idx in range(n): trace_headers,
if parent_req is not None: prompt_adapter_request,
request_id, params = parent_req.get_child_info(idx) priority)
# 2) Process raw inputs into the request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)
# 3) Make a new RequestState and queue. n = params.n if isinstance(params, SamplingParams) else 1
self.output_processor.add_request(request, parent_req, idx)
# 3) Add the request to EngineCore. if n == 1:
# Make a new RequestState and queue.
self.output_processor.add_request(request, None, 0)
# Add the request to EngineCore.
self.engine_core.add_request(request) self.engine_core.add_request(request)
return
# Fan out child requests (for n>1).
parent_req = ParentRequest(request_id, params)
for idx in range(n):
request_id, params = parent_req.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
# Make a new RequestState and queue.
self.output_processor.add_request(child_request, parent_req, idx)
# Add the request to EngineCore.
self.engine_core.add_request(child_request)
def step(self) -> list[RequestOutput]: def step(self) -> list[RequestOutput]:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from copy import copy from copy import copy
from typing import Optional, Union from typing import Optional
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.metrics.stats import IterationStats from vllm.v1.metrics.stats import IterationStats
...@@ -43,16 +42,6 @@ class ParentRequest: ...@@ -43,16 +42,6 @@ class ParentRequest:
self.max_num_generation_tokens = 0 self.max_num_generation_tokens = 0
self.cached_child_sampling_params = None self.cached_child_sampling_params = None
@classmethod
def from_params(
cls,
request_id: str,
params: Union[SamplingParams, PoolingParams],
) -> Optional['ParentRequest']:
if not isinstance(params, SamplingParams) or params.n == 1:
return None
return cls(request_id, params)
def _get_child_sampling_params( def _get_child_sampling_params(
self, self,
index: int, index: int,
......
...@@ -173,7 +173,6 @@ class Processor: ...@@ -173,7 +173,6 @@ class Processor:
# 3. Apply prompt adapter to prompt token ids if one exists. # 3. Apply prompt adapter to prompt token ids if one exists.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
return_mm_hashes=self.use_hash, return_mm_hashes=self.use_hash,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment