Commit 6eb31507 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

fix: mypy error (#543)


Co-authored-by: finofliu <finofliu@tencent.com>
parent 99cc11e6
...@@ -41,7 +41,8 @@ from tensorrt_llm.serve.openai_protocol import ( ...@@ -41,7 +41,8 @@ from tensorrt_llm.serve.openai_protocol import (
ToolCall, ToolCall,
UsageInfo, UsageInfo,
) )
from transformers import AutoTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
logger.set_level("debug") logger.set_level("debug")
...@@ -71,7 +72,11 @@ def parse_chat_message_content( ...@@ -71,7 +72,11 @@ def parse_chat_message_content(
class BaseChatProcessor: class BaseChatProcessor:
def __init__(self, model: str, tokenizer: AutoTokenizer): def __init__(
self,
model: str,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
):
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -122,7 +127,10 @@ class BaseChatProcessor: ...@@ -122,7 +127,10 @@ class BaseChatProcessor:
class ChatProcessor(BaseChatProcessor): class ChatProcessor(BaseChatProcessor):
def __init__( def __init__(
self, model: str, tokenizer: AutoTokenizer, using_engine_generator: bool = False self,
model: str,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
using_engine_generator: bool = False,
): ):
super().__init__(model, tokenizer) super().__init__(model, tokenizer)
self.using_engine_generator = using_engine_generator self.using_engine_generator = using_engine_generator
...@@ -269,7 +277,7 @@ class ChatProcessor(BaseChatProcessor): ...@@ -269,7 +277,7 @@ class ChatProcessor(BaseChatProcessor):
if request.tools is None if request.tools is None
else [tool.model_dump() for tool in request.tools] else [tool.model_dump() for tool in request.tools]
) )
prompt: str = self.tokenizer.apply_chat_template( prompt = self.tokenizer.apply_chat_template(
conversation=conversation, conversation=conversation,
tokenize=False, tokenize=False,
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
...@@ -329,7 +337,11 @@ class ChatProcessor(BaseChatProcessor): ...@@ -329,7 +337,11 @@ class ChatProcessor(BaseChatProcessor):
class CompletionsProcessor: class CompletionsProcessor:
def __init__(self, model: str, tokenizer: AutoTokenizer): def __init__(
self,
model: str,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
):
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment