"docs/vscode:/vscode.git/clone" did not exist on "24d6ea8afdb13ceee95b36645ba61a641f9a2f7f"
Unverified commit 625ccd1c, authored by Jiangyun Zhu and committed by GitHub
Browse files

[Bugfix] Replace custom Encoding class with BatchEncoding in MistralTokenizer (#22786)


Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
parent 92ff41ab
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
import huggingface_hub import huggingface_hub
import regex as re import regex as re
from huggingface_hub import HfApi, hf_hub_download from huggingface_hub import HfApi, hf_hub_download
from transformers.tokenization_utils_base import BatchEncoding
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_base import TokenizerBase from vllm.transformers_utils.tokenizer_base import TokenizerBase
...@@ -27,11 +27,6 @@ if TYPE_CHECKING: ...@@ -27,11 +27,6 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class Encoding:
input_ids: Union[list[int], list[list[int]]]
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
# SEE: https://github.com/vllm-project/vllm/pull/9951 # SEE: https://github.com/vllm-project/vllm/pull/9951
# Credits go to: @gcalmettes # Credits go to: @gcalmettes
...@@ -359,7 +354,7 @@ class MistralTokenizer(TokenizerBase): ...@@ -359,7 +354,7 @@ class MistralTokenizer(TokenizerBase):
# For str, single prompt text # For str, single prompt text
else: else:
input_ids = self.encode_one(text, truncation, max_length) input_ids = self.encode_one(text, truncation, max_length)
return Encoding(input_ids=input_ids) return BatchEncoding({"input_ids": input_ids})
def get_vocab(self) -> dict[str, int]: def get_vocab(self) -> dict[str, int]:
# NB: the dictionary form of the vocabulary collapses token ids that map # NB: the dictionary form of the vocabulary collapses token ids that map
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment