Unverified commit 85eb6318, authored by Woosuk Kwon, committed by GitHub

Use slow tokenizer for LLaMA (#84)

parent add055e1
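
The fast (Rust-based) LLaMA tokenizer in transformers triggered a protobuf-related bug (see issue #80), so this commit routes tokenizer construction through a new get_tokenizer helper that falls back to the slow tokenizer for LLaMA models. A minimal sketch of the equivalent direct transformers call; the model path below is a placeholder, not taken from the commit:

    from transformers import AutoTokenizer

    # use_fast=False selects the slow, sentencepiece-based tokenizer,
    # sidestepping the fast tokenizer's protobuf-related bug for LLaMA.
    tokenizer = AutoTokenizer.from_pretrained("/path/to/llama-7b", use_fast=False)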
@@ -7,12 +7,12 @@ from typing import List, Dict, Optional
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 import ray
-from transformers import AutoTokenizer
 import uvicorn
 
 from cacheflow.core.server import (Server, add_server_arguments,
                                    process_server_arguments,
                                    initialize_cluster)
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
@@ -44,7 +44,7 @@ class FastAPIServer:
     ):
         self.block_size = block_size
-        self.tokenizer = AutoTokenizer.from_pretrained(model)
+        self.tokenizer = get_tokenizer(model)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         if server_use_ray:

 import time
 from typing import List, Optional, Tuple
 
-from transformers import AutoTokenizer
-
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
@@ -21,7 +20,7 @@ class SimpleFrontend:
     ) -> None:
         self.block_size = block_size
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer = get_tokenizer(model_name)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []

+from typing import Union
+
+from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
+
+_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
+    # LLaMA fast tokenizer has a bug related to protobuf.
+    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
+    "llama",
+]
+
+
+def get_tokenizer(
+    model_name: str,
+    *args,
+    **kwargs,
+) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    config = AutoConfig.from_pretrained(model_name)
+    if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        kwargs["use_fast"] = False
+    return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
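
For reference, a minimal usage sketch of the new helper; the model names below are illustrative placeholders, not taken from the commit:

    from cacheflow.frontend.utils import get_tokenizer

    # Non-LLaMA models keep transformers' default (fast) tokenizer.
    tokenizer = get_tokenizer("facebook/opt-125m")

    # For a LLaMA checkpoint, config.model_type == "llama", so the helper
    # forces use_fast=False and returns the slow tokenizer instead.
    tokenizer = get_tokenizer("/path/to/llama-7b")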