norm / vllm · Commits · 7013a801
"docs/vscode:/vscode.git/clone" did not exist on "741777b5b5bdfb498173d3911bdf4978b49a12ec"
Unverified commit 7013a801, authored Oct 30, 2023 by Dan Lord; committed by GitHub on Oct 30, 2023.
Add support for `spaces_between_special_tokens`
Parent: 79a30912
Showing 5 changed files with 28 additions and 7 deletions (+28 / -7):
vllm/engine/llm_engine.py               +3   -3
vllm/entrypoints/openai/api_server.py   +4   -0
vllm/entrypoints/openai/protocol.py     +2   -0
vllm/sampling_params.py                 +7   -1
vllm/transformers_utils/tokenizer.py    +12  -3
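Taken together, the change threads a new `spaces_between_special_tokens` flag (default `True`, preserving the previous behaviour) from the OpenAI-compatible request models through `SamplingParams` and into the incremental detokenizer. A minimal offline-usage sketch, assuming a local vLLM install of this revision (the model name is illustrative):

```python
from vllm import LLM, SamplingParams

# Any small HF model works here; "facebook/opt-125m" is only an example.
llm = LLM(model="facebook/opt-125m")

params = SamplingParams(
    max_tokens=32,
    skip_special_tokens=False,            # keep special tokens in the output text
    spaces_between_special_tokens=False,  # new flag: join them without extra spaces
)

outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
```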
vllm/engine/llm_engine.py

@@ -632,8 +632,7 @@ class LLMEngine:
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         self.last_logging_time = now
 
-    def _decode_sequence(self, seq: Sequence,
-                         sampling_params: SamplingParams) -> None:
+    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
@@ -642,7 +641,8 @@ class LLMEngine:
              prev_tokens=seq.tokens,
              prefix_offset=seq.prefix_offset,
              read_offset=seq.read_offset,
-             skip_special_tokens=sampling_params.skip_special_tokens,
+             skip_special_tokens=prms.skip_special_tokens,
+             spaces_between_special_tokens=prms.spaces_between_special_tokens,
          )
         if seq.tokens is None:
             seq.tokens = new_tokens
vllm/entrypoints/openai/api_server.py

@@ -212,6 +212,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     request_id = f"cmpl-{random_uuid()}"
     created_time = int(time.monotonic())
     try:
+        spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
             n=request.n,
             presence_penalty=request.presence_penalty,
@@ -226,6 +227,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
             skip_special_tokens=request.skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -413,6 +415,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     created_time = int(time.monotonic())
     try:
+        spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
             n=request.n,
             best_of=request.best_of,
@@ -428,6 +431,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
             skip_special_tokens=request.skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
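With the server plumbing above, the flag can be set per request on both the chat and plain completions endpoints. A sketch of such a request, assuming an OpenAI-compatible vLLM server is already running locally (URL and model name are illustrative):

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "Hello, my name is",
        "max_tokens": 32,
        "skip_special_tokens": False,
        "spaces_between_special_tokens": False,  # field added by this commit
    },
)
print(resp.json()["choices"][0]["text"])
```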
vllm/entrypoints/openai/protocol.py

@@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel):
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
 
 
 class CompletionRequest(BaseModel):
@@ -98,6 +99,7 @@ class CompletionRequest(BaseModel):
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
+    spaces_between_special_tokens: Optional[bool] = True
 
 
 class LogProbs(BaseModel):
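Since the new field carries a default, clients that omit it keep the old behaviour. A small sketch of the request model (the model name and prompt are illustrative):

```python
from vllm.entrypoints.openai.protocol import CompletionRequest

# Omitted -> defaults to True (previous behaviour).
req = CompletionRequest(model="facebook/opt-125m", prompt="Hello")
print(req.spaces_between_special_tokens)  # True

# Explicit per-request opt-out.
req = CompletionRequest(model="facebook/opt-125m", prompt="Hello",
                        spaces_between_special_tokens=False)
print(req.spaces_between_special_tokens)  # False
```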
vllm/sampling_params.py

@@ -71,6 +71,8 @@ class SamplingParams:
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
         skip_special_tokens: Whether to skip special tokens in the output.
+        spaces_between_special_tokens: Whether to add spaces between special
+            tokens in the output. Defaults to True.
     """
 
     def __init__(
@@ -93,6 +95,7 @@ class SamplingParams:
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
         skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -120,6 +123,7 @@ class SamplingParams:
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
         self.skip_special_tokens = skip_special_tokens
+        self.spaces_between_special_tokens = spaces_between_special_tokens
 
         self._verify_args()
         if self.use_beam_search:
@@ -222,4 +226,6 @@ class SamplingParams:
                 f"max_tokens={self.max_tokens}, "
                 f"logprobs={self.logprobs}, "
                 f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens})")
+                f"skip_special_tokens={self.skip_special_tokens}, "
+                "spaces_between_special_tokens="
+                f"{self.spaces_between_special_tokens})")
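As the `__repr__` change indicates, the flag is stored on `SamplingParams` alongside `skip_special_tokens`. A quick sketch:

```python
from vllm import SamplingParams

params = SamplingParams(skip_special_tokens=False,
                        spaces_between_special_tokens=False)
print(params)
# ...the repr now ends with:
# skip_special_tokens=False, spaces_between_special_tokens=False)
```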
vllm/transformers_utils/tokenizer.py

@@ -73,6 +73,7 @@ def _convert_tokens_to_string_with_added_encoders(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     output_tokens: List[str],
     skip_special_tokens: bool,
+    spaces_between_special_tokens: bool,
 ) -> str:
     # Adapted from
     # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
@@ -96,7 +97,10 @@ def _convert_tokens_to_string_with_added_encoders(
     if current_sub_text:
         sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
         sub_texts.append(sub_text)
-    return " ".join(sub_texts)
+    if spaces_between_special_tokens:
+        return " ".join(sub_texts)
+    else:
+        return "".join(sub_texts)
 
 
 # Based on
@@ -109,6 +113,7 @@ def detokenize_incrementally(
     prefix_offset: int = 0,
     read_offset: int = 0,
     skip_special_tokens: bool = False,
+    spaces_between_special_tokens: bool = True,
 ) -> Tuple[List[str], str, int, int]:
     new_token_id = all_input_ids[-1]
     # This is the first iteration for this sequence
@@ -143,11 +148,15 @@ def detokenize_incrementally(
         prefix_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:read_offset],
-            skip_special_tokens=skip_special_tokens)
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
         new_text = _convert_tokens_to_string_with_added_encoders(
             tokenizer,
             output_tokens[prefix_offset:],
-            skip_special_tokens=skip_special_tokens)
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
         if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
             # utf-8 char at the end means it's a potential unfinished byte sequence
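The observable difference comes down to how `_convert_tokens_to_string_with_added_encoders` joins its decoded sub-texts. A toy illustration (the sub-text strings below are made up; only the join logic mirrors the code above):

```python
# Pieces as they might come back from tokenizer.convert_tokens_to_string,
# with added/special tokens kept as separate sub-texts.
sub_texts = ["<|system|>", "You are a helpful assistant.", "<|end|>"]

# spaces_between_special_tokens=True (default, previous behaviour):
print(" ".join(sub_texts))  # <|system|> You are a helpful assistant. <|end|>

# spaces_between_special_tokens=False (new opt-out):
print("".join(sub_texts))   # <|system|>You are a helpful assistant.<|end|>
```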