Commit 85eb6318 (unverified)
Authored May 09, 2023 by Woosuk Kwon; committed by GitHub, May 09, 2023
Use slow tokenizer for LLaMA (#84)
Parent: add055e1

Showing 3 changed files with 26 additions and 5 deletions (+26 −5)
cacheflow/frontend/fastapi_frontend.py   +2 −2
cacheflow/frontend/simple_frontend.py    +2 −3
cacheflow/frontend/utils.py              +22 −0
cacheflow/frontend/fastapi_frontend.py
@@ -7,12 +7,12 @@ from typing import List, Dict, Optional
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 import ray
-from transformers import AutoTokenizer
 import uvicorn
 
 from cacheflow.core.server import (Server, add_server_arguments,
                                    process_server_arguments,
                                    initialize_cluster)
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
@@ -44,7 +44,7 @@ class FastAPIServer:
     ):
         self.block_size = block_size
-        self.tokenizer = AutoTokenizer.from_pretrained(model)
+        self.tokenizer = get_tokenizer(model)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         if server_use_ray:
cacheflow/frontend/simple_frontend.py
 import time
 from typing import List, Optional, Tuple
 
-from transformers import AutoTokenizer
-
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
@@ -21,7 +20,7 @@ class SimpleFrontend:
     ) -> None:
         self.block_size = block_size
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer = get_tokenizer(model_name)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []
cacheflow/frontend/utils.py   (new file, mode 100644)
+from typing import Union
+
+from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
+
+
+_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
+    # LLaMA fast tokenizer has a bug related to protobuf.
+    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
+    "llama",
+]
+
+
+def get_tokenizer(
+    model_name: str,
+    *args,
+    **kwargs,
+) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    config = AutoConfig.from_pretrained(model_name)
+    if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        kwargs["use_fast"] = False
+    return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
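A quick usage sketch of the new helper (illustrative only, not part of the commit; the model ids below are hypothetical examples): for a checkpoint whose config.model_type is "llama", get_tokenizer forces use_fast=False and returns the slow SentencePiece-based tokenizer, while every other model type keeps the default fast tokenizer.

    from cacheflow.frontend.utils import get_tokenizer

    # Hypothetical Hugging Face model ids, chosen only for illustration.
    opt_tok = get_tokenizer("facebook/opt-125m")      # model_type "opt": fast tokenizer
    llama_tok = get_tokenizer("huggyllama/llama-7b")  # model_type "llama": slow tokenizer forced
    print(llama_tok.is_fast)  # False

Note that the override is applied before delegating to AutoTokenizer.from_pretrained, so a caller-supplied use_fast=True is silently replaced for any model type listed in _MODEL_TYPES_WITH_SLOW_TOKENIZER.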