Unverified commit 85eb6318 authored by Woosuk Kwon, committed by GitHub

Use slow tokenizer for LLaMA (#84)

parent add055e1
@@ -7,12 +7,12 @@ from typing import List, Dict, Optional
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 import ray
-from transformers import AutoTokenizer
 import uvicorn

 from cacheflow.core.server import (Server, add_server_arguments,
                                    process_server_arguments,
                                    initialize_cluster)
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
@@ -44,7 +44,7 @@ class FastAPIServer:
     ):
         self.block_size = block_size
-        self.tokenizer = AutoTokenizer.from_pretrained(model)
+        self.tokenizer = get_tokenizer(model)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         if server_use_ray:
...
 import time
 from typing import List, Optional, Tuple
-from transformers import AutoTokenizer
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
@@ -21,7 +20,7 @@ class SimpleFrontend:
     ) -> None:
         self.block_size = block_size
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer = get_tokenizer(model_name)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []
...
from typing import Union

from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
    # LLaMA fast tokenizer has a bug related to protobuf.
    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
    "llama",
]


def get_tokenizer(
    model_name: str,
    *args,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    config = AutoConfig.from_pretrained(model_name)
    if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
        kwargs["use_fast"] = False
    return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
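
For reference, a minimal usage sketch of the new helper follows; the model names below are illustrative examples and are not part of this commit.

from cacheflow.frontend.utils import get_tokenizer

# For LLaMA checkpoints, get_tokenizer forces use_fast=False, so the slow
# (SentencePiece-based) tokenizer is loaded instead of the buggy fast one.
llama_tokenizer = get_tokenizer("huggyllama/llama-7b")

# Other model types keep AutoTokenizer's default behavior, which usually
# returns the fast Rust-backed tokenizer.
opt_tokenizer = get_tokenizer("facebook/opt-125m")

token_ids = llama_tokenizer("Hello, world!").input_ids

Keying the check off config.model_type rather than the repository name means any LLaMA-derived checkpoint picks up the workaround, regardless of how the model is named on the Hub.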