Unverified Commit 391d7b27 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix usage of `deprecated` decorator (#11025)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent d1f6d1c8
...@@ -677,12 +677,10 @@ class LLMEngine: ...@@ -677,12 +677,10 @@ class LLMEngine:
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
@overload @overload
@deprecated("'inputs' will be renamed to 'prompt")
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
*, prompt: PromptType,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -693,10 +691,12 @@ class LLMEngine: ...@@ -693,10 +691,12 @@ class LLMEngine:
... ...
@overload @overload
@deprecated("'inputs' will be renamed to 'prompt")
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
prompt: PromptType, *,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
......
...@@ -35,11 +35,9 @@ class RPCProcessRequest: ...@@ -35,11 +35,9 @@ class RPCProcessRequest:
priority: int = 0 priority: int = 0
@overload @overload
@deprecated("'inputs' will be renamed to 'prompt")
def __init__( def __init__(
self, self,
*, prompt: PromptType,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -50,9 +48,11 @@ class RPCProcessRequest: ...@@ -50,9 +48,11 @@ class RPCProcessRequest:
... ...
@overload @overload
@deprecated("'inputs' will be renamed to 'prompt")
def __init__( def __init__(
self, self,
prompt: PromptType, *,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
......
...@@ -415,11 +415,9 @@ class MQLLMEngineClient(EngineClient): ...@@ -415,11 +415,9 @@ class MQLLMEngineClient(EngineClient):
return ENGINE_DEAD_ERROR(self._errored_with) return ENGINE_DEAD_ERROR(self._errored_with)
@overload @overload
@deprecated("'inputs' will be renamed to 'prompt")
def generate( def generate(
self, self,
*, prompt: PromptType,
inputs: PromptType,
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -430,9 +428,11 @@ class MQLLMEngineClient(EngineClient): ...@@ -430,9 +428,11 @@ class MQLLMEngineClient(EngineClient):
... ...
@overload @overload
@deprecated("'inputs' will be renamed to 'prompt")
def generate( def generate(
self, self,
prompt: PromptType, *,
inputs: PromptType,
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -487,11 +487,9 @@ class MQLLMEngineClient(EngineClient): ...@@ -487,11 +487,9 @@ class MQLLMEngineClient(EngineClient):
prompt_adapter_request, priority) prompt_adapter_request, priority)
@overload @overload
@deprecated("'inputs' will be renamed to 'prompt")
def encode( def encode(
self, self,
*, prompt: PromptType,
inputs: PromptType,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -501,9 +499,11 @@ class MQLLMEngineClient(EngineClient): ...@@ -501,9 +499,11 @@ class MQLLMEngineClient(EngineClient):
... ...
@overload @overload
@deprecated("'inputs' will be renamed to 'prompt")
def encode( def encode(
self, self,
prompt: PromptType, *,
inputs: PromptType,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
......
...@@ -252,8 +252,21 @@ class LLM: ...@@ -252,8 +252,21 @@ class LLM:
else: else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
@overload
def generate(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single (prompt + optional token ids) @overload # LEGACY: single (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts") @deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate( def generate(
self, self,
prompts: str, prompts: str,
...@@ -266,7 +279,7 @@ class LLM: ...@@ -266,7 +279,7 @@ class LLM:
... ...
@overload # LEGACY: multi (prompt + optional token ids) @overload # LEGACY: multi (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts") @deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate( def generate(
self, self,
prompts: List[str], prompts: List[str],
...@@ -279,7 +292,7 @@ class LLM: ...@@ -279,7 +292,7 @@ class LLM:
... ...
@overload # LEGACY: single (token ids + optional prompt) @overload # LEGACY: single (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts") @deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate( def generate(
self, self,
prompts: Optional[str] = None, prompts: Optional[str] = None,
...@@ -293,7 +306,7 @@ class LLM: ...@@ -293,7 +306,7 @@ class LLM:
... ...
@overload # LEGACY: multi (token ids + optional prompt) @overload # LEGACY: multi (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts") @deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate( def generate(
self, self,
prompts: Optional[List[str]] = None, prompts: Optional[List[str]] = None,
...@@ -307,7 +320,7 @@ class LLM: ...@@ -307,7 +320,7 @@ class LLM:
... ...
@overload # LEGACY: single or multi token ids [pos-only] @overload # LEGACY: single or multi token ids [pos-only]
@deprecated("'prompt_token_ids' will become part of 'prompts") @deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate( def generate(
self, self,
prompts: None, prompts: None,
...@@ -318,19 +331,6 @@ class LLM: ...@@ -318,19 +331,6 @@ class LLM:
) -> List[RequestOutput]: ) -> List[RequestOutput]:
... ...
@overload
def generate(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@deprecate_kwargs( @deprecate_kwargs(
"prompt_token_ids", "prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY, is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
...@@ -672,8 +672,21 @@ class LLM: ...@@ -672,8 +672,21 @@ class LLM:
lora_request=lora_request, lora_request=lora_request,
) )
@overload
def encode(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: single (prompt + optional token ids) @overload # LEGACY: single (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts") @deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode( def encode(
self, self,
prompts: str, prompts: str,
...@@ -686,7 +699,7 @@ class LLM: ...@@ -686,7 +699,7 @@ class LLM:
... ...
@overload # LEGACY: multi (prompt + optional token ids) @overload # LEGACY: multi (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts") @deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode( def encode(
self, self,
prompts: List[str], prompts: List[str],
...@@ -699,7 +712,7 @@ class LLM: ...@@ -699,7 +712,7 @@ class LLM:
... ...
@overload # LEGACY: single (token ids + optional prompt) @overload # LEGACY: single (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts") @deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode( def encode(
self, self,
prompts: Optional[str] = None, prompts: Optional[str] = None,
...@@ -713,7 +726,7 @@ class LLM: ...@@ -713,7 +726,7 @@ class LLM:
... ...
@overload # LEGACY: multi (token ids + optional prompt) @overload # LEGACY: multi (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts") @deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode( def encode(
self, self,
prompts: Optional[List[str]] = None, prompts: Optional[List[str]] = None,
...@@ -727,7 +740,7 @@ class LLM: ...@@ -727,7 +740,7 @@ class LLM:
... ...
@overload # LEGACY: single or multi token ids [pos-only] @overload # LEGACY: single or multi token ids [pos-only]
@deprecated("'prompt_token_ids' will become part of 'prompts") @deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode( def encode(
self, self,
prompts: None, prompts: None,
...@@ -738,19 +751,6 @@ class LLM: ...@@ -738,19 +751,6 @@ class LLM:
) -> List[PoolingRequestOutput]: ) -> List[PoolingRequestOutput]:
... ...
@overload
def encode(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[PoolingRequestOutput]:
...
@deprecate_kwargs( @deprecate_kwargs(
"prompt_token_ids", "prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY, is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment