Unverified Commit 64448248 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Implement `TokenizerLike.convert_tokens_to_ids` (#31796)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent bf0f3a46
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, overload
from transformers import BatchEncoding from transformers import BatchEncoding
...@@ -65,6 +65,7 @@ class DeepseekV32Tokenizer(CachedHfTokenizer): ...@@ -65,6 +65,7 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
drop_thinking = messages[-1]["role"] == "user" drop_thinking = messages[-1]["role"] == "user"
encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking) encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking)
prompt_str = encode_messages(messages, **encode_config) # type: ignore prompt_str = encode_messages(messages, **encode_config) # type: ignore
if kwargs.get("tokenize", True): if kwargs.get("tokenize", True):
...@@ -161,6 +162,15 @@ class DeepseekV32Tokenizer(CachedHfTokenizer): ...@@ -161,6 +162,15 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
) )
@overload
def convert_tokens_to_ids(self, tokens: str) -> int: ...
@overload
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
return self.tokenizer.convert_tokens_to_ids(tokens)
def convert_tokens_to_string(self, tokens: list[str]) -> str: def convert_tokens_to_string(self, tokens: list[str]) -> str:
return self.tokenizer.convert_tokens_to_string(tokens) return self.tokenizer.convert_tokens_to_string(tokens)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast, overload
from mistral_common.protocol.instruct.request import ( from mistral_common.protocol.instruct.request import (
ChatCompletionRequest as MistralChatCompletionRequest, ChatCompletionRequest as MistralChatCompletionRequest,
...@@ -441,6 +441,15 @@ class MistralTokenizer(TokenizerLike): ...@@ -441,6 +441,15 @@ class MistralTokenizer(TokenizerLike):
ids, skip_special_tokens=skip_special_tokens ids, skip_special_tokens=skip_special_tokens
) )
@overload
def convert_tokens_to_ids(self, tokens: str) -> int: ...
@overload
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
return self.transformers_tokenizer.convert_tokens_to_ids(tokens)
def convert_tokens_to_string(self, tokens: list[str]) -> str: def convert_tokens_to_string(self, tokens: list[str]) -> str:
to_decode_special_tokens = {SpecialTokens.tool_calls} to_decode_special_tokens = {SpecialTokens.tool_calls}
if self.is_tekken: if self.is_tekken:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol from typing import TYPE_CHECKING, Any, Protocol, overload
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import BatchEncoding from transformers import BatchEncoding
...@@ -100,6 +100,15 @@ class TokenizerLike(Protocol): ...@@ -100,6 +100,15 @@ class TokenizerLike(Protocol):
) -> str | list[int]: ) -> str | list[int]:
raise NotImplementedError raise NotImplementedError
@overload
def convert_tokens_to_ids(self, tokens: str) -> int: ...
@overload
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
raise NotImplementedError
def convert_tokens_to_string(self, tokens: list[str]) -> str: def convert_tokens_to_string(self, tokens: list[str]) -> str:
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment