Unverified commit cfaa6008, authored by Cyrus Leung and committed by GitHub
Browse files

[Bugfix] Access `get_vocab` instead of `vocab` in tool parsers (#9188)

parent 21906a6f
import importlib import importlib
import importlib.util import importlib.util
import os import os
from functools import cached_property
from typing import Callable, Dict, List, Optional, Sequence, Type, Union from typing import Callable, Dict, List, Optional, Sequence, Type, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...@@ -29,6 +30,12 @@ class ToolParser: ...@@ -29,6 +30,12 @@ class ToolParser:
self.model_tokenizer = tokenizer self.model_tokenizer = tokenizer
@cached_property
def vocab(self) -> Dict[str, int]:
# Vocabulary mapping obtained from self.model_tokenizer.get_vocab().
# Tool parser subclasses look up special-token ids with
# self.vocab.get(<token string>), which yields None for a missing token.
# cached_property evaluates get_vocab() on first access and reuses the
# stored result for later accesses on the same instance.
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
def adjust_request( def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest: self, request: ChatCompletionRequest) -> ChatCompletionRequest:
""" """
......
...@@ -50,10 +50,9 @@ class Hermes2ProToolParser(ToolParser): ...@@ -50,10 +50,9 @@ class Hermes2ProToolParser(ToolParser):
raise ValueError( raise ValueError(
"The model tokenizer must be passed to the ToolParser " "The model tokenizer must be passed to the ToolParser "
"constructor during construction.") "constructor during construction.")
self.tool_call_start_token_id: int = self.model_tokenizer.vocab.get( self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token, None) self.tool_call_start_token)
self.tool_call_end_token_id: int = self.model_tokenizer.vocab.get( self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
self.tool_call_end_token, None)
if not self.tool_call_start_token_id or not self.tool_call_end_token_id: if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
raise RuntimeError( raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end " "Hermes 2 Pro Tool parser could not locate tool call start/end "
......
...@@ -61,8 +61,7 @@ class MistralToolParser(ToolParser): ...@@ -61,8 +61,7 @@ class MistralToolParser(ToolParser):
self.streamed_args_for_tool: List[str] = [ self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list ] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]" self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.model_tokenizer.get_vocab().get( self.bot_token_id = self.vocab.get(self.bot_token)
self.bot_token, None)
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
if not self.bot_token_id: if not self.bot_token_id:
raise RuntimeError( raise RuntimeError(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment