Commit 3bda0405 authored by zhuwenwen's avatar zhuwenwen
Browse files

[Bugfix] TypeError: ChatGLMTokenizer._pad() got an unexpected keyword argument 'padding_side'

parent 91577443
import os
import warnings
from pathlib import Path
from types import MethodType
from typing import Optional, Union
import huggingface_hub
......@@ -10,8 +11,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import (BaichuanTokenizer,
MistralTokenizer)
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import make_async
......@@ -59,6 +59,26 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
return tokenizer
def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
"""Patch _pad method to accept `padding_side` for older tokenizers."""
orig_pad = tokenizer._pad
def _pad(
self: PreTrainedTokenizer,
*args,
padding_side: Optional[str] = None,
**kwargs,
):
if padding_side is not None and padding_side != self.padding_side:
msg = ("`padding_side` argument is not supported by "
f"{type(tokenizer).__name__} and will be ignored.")
warnings.warn(msg, stacklevel=2)
return orig_pad(*args, **kwargs)
tokenizer._pad = MethodType(_pad, tokenizer)
def get_tokenizer(
tokenizer_name: Union[str, Path],
*args,
......@@ -138,19 +158,12 @@ def get_tokenizer(
raise RuntimeError(err_msg) from e
else:
raise e
except AttributeError as e:
if "BaichuanTokenizer" in str(e):
# This is for the error "'BaichuanTokenizer' object has no
# attribute 'sp_model'".
tokenizer = BaichuanTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
else:
raise e
# NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324
if type(tokenizer).__name__ in ("ChatGLMTokenizer",
"ChatGLM4Tokenizer"):
assert isinstance(tokenizer, PreTrainedTokenizer)
patch_padding_side(tokenizer)
if not isinstance(tokenizer, PreTrainedTokenizerFast):
logger.warning(
......@@ -167,7 +180,7 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
return None
try:
tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
except OSError as e:
except Exception as e:
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger.warning(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment