vllm · commit 7013a801 (unverified)
Authored Oct 30, 2023 by Dan Lord; committed by GitHub on Oct 30, 2023

Add support for `spaces_between_special_tokens`
Parent: 79a30912
Showing 5 changed files with 28 additions and 7 deletions.
Changed files:

  vllm/engine/llm_engine.py               +3  -3
  vllm/entrypoints/openai/api_server.py   +4  -0
  vllm/entrypoints/openai/protocol.py     +2  -0
  vllm/sampling_params.py                 +7  -1
  vllm/transformers_utils/tokenizer.py    +12 -3
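The new flag controls whether the detokenizer inserts a space between the text fragments surrounding added special tokens, matching the `spaces_between_special_tokens` argument that Hugging Face slow tokenizers accept in `decode`. A minimal sketch of the new knob through vLLM's offline API follows; the model name and prompt are placeholders, not part of this commit:

```python
# Minimal sketch, assuming a local model is available; the model name and
# prompt are placeholders, not part of this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

# skip_special_tokens=False keeps special tokens in the output text;
# spaces_between_special_tokens=False (new in this commit) joins the pieces
# without inserting spaces, e.g. "</s>foo" instead of "</s> foo".
params = SamplingParams(
    max_tokens=16,
    skip_special_tokens=False,
    spaces_between_special_tokens=False,
)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
```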
vllm/engine/llm_engine.py

@@ -632,8 +632,7 @@ class LLMEngine:
                      f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         self.last_logging_time = now
 
-    def _decode_sequence(self, seq: Sequence,
-                         sampling_params: SamplingParams) -> None:
+    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
@@ -642,7 +641,8 @@ class LLMEngine:
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
-            skip_special_tokens=sampling_params.skip_special_tokens,
+            skip_special_tokens=prms.skip_special_tokens,
+            spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )
         if seq.tokens is None:
             seq.tokens = new_tokens
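For context on where this lands: `_decode_sequence` runs once per newly sampled token and threads both detokenization flags from the request's `SamplingParams` into the incremental detokenizer. The driver loop below is only a rough sketch of that flow; the `detokenize_incrementally` signature comes from this diff, while the loop and the `decode_stream` helper are illustrative.

```python
# Rough sketch of the per-token decoding loop; only the call into
# detokenize_incrementally reflects this diff, the rest is illustrative.
from vllm.transformers_utils.tokenizer import detokenize_incrementally

def decode_stream(tokenizer, all_input_ids, prms):
    tokens, prefix_offset, read_offset = None, 0, 0
    text = ""
    for i in range(1, len(all_input_ids) + 1):
        (new_tokens, new_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             tokenizer,
             all_input_ids=all_input_ids[:i],
             prev_tokens=tokens,
             prefix_offset=prefix_offset,
             read_offset=read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )
        # Mirrors the seq.tokens bookkeeping in _decode_sequence.
        tokens = new_tokens if tokens is None else tokens + new_tokens
        text += new_text
    return text
```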
vllm/entrypoints/openai/api_server.py

@@ -212,6 +212,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     request_id = f"cmpl-{random_uuid()}"
     created_time = int(time.monotonic())
     try:
+        spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
             n=request.n,
             presence_penalty=request.presence_penalty,
@@ -226,6 +227,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
             skip_special_tokens=request.skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -413,6 +415,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     created_time = int(time.monotonic())
     try:
+        spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
             n=request.n,
             best_of=request.best_of,
@@ -428,6 +431,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
             skip_special_tokens=request.skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
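With both endpoints reading the field off the request, a client can opt out per request. An illustrative call against a locally running OpenAI-compatible server; the URL and model name are assumptions, not part of this commit:

```python
# Illustrative request against a running vLLM OpenAI-compatible server;
# the URL and model name are assumed.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        "skip_special_tokens": False,
        # New field in this commit; omitting it keeps the default of true.
        "spaces_between_special_tokens": False,
    },
)
print(resp.json()["choices"][0]["text"])
```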
vllm/entrypoints/openai/protocol.py

@@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel):
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
 
 
 class CompletionRequest(BaseModel):
@@ -98,6 +99,7 @@ class CompletionRequest(BaseModel):
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
 
 
 class LogProbs(BaseModel):
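Both request models default the new field to `True`, so clients that omit it keep the previous spaced output. A quick illustrative check of the defaults; the `model` and `prompt` values are arbitrary:

```python
# Illustrative check of the pydantic defaults added in this commit.
from vllm.entrypoints.openai.protocol import CompletionRequest

req = CompletionRequest(model="m", prompt="Hello")
assert req.spaces_between_special_tokens is True  # default preserved

req = CompletionRequest(model="m",
                        prompt="Hello",
                        spaces_between_special_tokens=False)
assert req.spaces_between_special_tokens is False  # explicit opt-out
```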
vllm/sampling_params.py

@@ -71,6 +71,8 @@ class SamplingParams:
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
         skip_special_tokens: Whether to skip special tokens in the output.
+        spaces_between_special_tokens: Whether to add spaces between special
+            tokens in the output. Defaults to True.
     """
 
     def __init__(
@@ -93,6 +95,7 @@ class SamplingParams:
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
         skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -120,6 +123,7 @@ class SamplingParams:
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
         self.skip_special_tokens = skip_special_tokens
+        self.spaces_between_special_tokens = spaces_between_special_tokens
         self._verify_args()
         if self.use_beam_search:
@@ -222,4 +226,6 @@ class SamplingParams:
                 f"max_tokens={self.max_tokens}, "
                 f"logprobs={self.logprobs}, "
                 f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens})")
+                f"skip_special_tokens={self.skip_special_tokens}, "
+                "spaces_between_special_tokens="
+                f"{self.spaces_between_special_tokens})")
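One small usability point grounded in the `__repr__` change above: the flag now shows up whenever sampling parameters are logged. For example (output abbreviated):

```python
# Illustrative: the extended __repr__ reports the new flag alongside
# skip_special_tokens.
from vllm import SamplingParams

params = SamplingParams(skip_special_tokens=False,
                        spaces_between_special_tokens=False)
print(params)
# ... skip_special_tokens=False, spaces_between_special_tokens=False)
```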
vllm/transformers_utils/tokenizer.py

@@ -73,6 +73,7 @@ def _convert_tokens_to_string_with_added_encoders(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     output_tokens: List[str],
     skip_special_tokens: bool,
+    spaces_between_special_tokens: bool,
 ) -> str:
     # Adapted from
     # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
@@ -96,7 +97,10 @@ def _convert_tokens_to_string_with_added_encoders(
     if current_sub_text:
         sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
         sub_texts.append(sub_text)
-    return " ".join(sub_texts)
+    if spaces_between_special_tokens:
+        return " ".join(sub_texts)
+    else:
+        return "".join(sub_texts)
 
 
 # Based on
@@ -109,6 +113,7 @@ def detokenize_incrementally(
     prefix_offset: int = 0,
     read_offset: int = 0,
     skip_special_tokens: bool = False,
+    spaces_between_special_tokens: bool = True,
 ) -> Tuple[List[str], str, int, int]:
     new_token_id = all_input_ids[-1]
     # This is the first iteration for this sequence
@@ -143,11 +148,15 @@ def detokenize_incrementally(
         prefix_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:read_offset],
-            skip_special_tokens=skip_special_tokens)
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
         new_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:],
-            skip_special_tokens=skip_special_tokens)
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
     if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
         # utf-8 char at the end means it's a potential unfinished byte sequence
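The behavioral core of the whole commit is the now-conditional join at the end of `_convert_tokens_to_string_with_added_encoders`. A self-contained toy version of just that join shows the two output shapes:

```python
# Toy version of the join this change makes conditional: sub-texts
# accumulated around added special tokens are joined with a space
# (old, unconditional behavior) or directly (new opt-out).
sub_texts = ["<s>", "Hello world", "</s>"]

def join_sub_texts(sub_texts, spaces_between_special_tokens):
    if spaces_between_special_tokens:
        return " ".join(sub_texts)
    return "".join(sub_texts)

print(join_sub_texts(sub_texts, True))   # <s> Hello world </s>
print(join_sub_texts(sub_texts, False))  # <s>Hello world</s>
```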