Commit 20f7cc4c (unverified)
Authored Sep 27, 2023 by Dan Lord; committed by GitHub on Sep 27, 2023
Parent: 649aa730

Add `skip_special_tokens` sampling params (#1186)
Showing 4 changed files with 14 additions and 4 deletions (+14 -4):

  vllm/engine/llm_engine.py               +4 -3
  vllm/entrypoints/openai/api_server.py   +2 -0
  vllm/entrypoints/openai/protocol.py     +2 -0
  vllm/sampling_params.py                 +6 -1
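Taken together, the change threads a per-request `skip_special_tokens` flag from the API layer down into incremental detokenization, keeping the old behavior (strip special tokens) as the default. As a hedged usage sketch through vLLM's offline Python entrypoint (the model name is only an example):

    # Minimal sketch, assuming vLLM's offline LLM entrypoint; the model
    # name is only an example.
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")

    # Default behavior: special tokens (e.g. </s>) are stripped from output.
    default_params = SamplingParams(max_tokens=32)

    # New behavior: keep special tokens in the decoded text.
    raw_params = SamplingParams(max_tokens=32, skip_special_tokens=False)

    for params in (default_params, raw_params):
        output = llm.generate(["Hello, my name is"], params)[0]
        print(output.outputs[0].text)

The per-file diffs follow.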
vllm/engine/llm_engine.py

@@ -387,7 +387,7 @@ class LLMEngine:
             child_seqs.append((parent, parent))

         for seq, _ in child_seqs:
-            self._decode_sequence(seq)
+            self._decode_sequence(seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)

         # Non-beam search case

@@ -621,7 +621,8 @@ class LLMEngine:
                         f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
             self.last_logging_time = now

-    def _decode_sequence(self, seq: Sequence) -> None:
+    def _decode_sequence(self, seq: Sequence,
+                         sampling_params: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(

@@ -630,7 +631,7 @@ class LLMEngine:
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
-            skip_special_tokens=True,
+            skip_special_tokens=sampling_params.skip_special_tokens,
         )
         if seq.tokens is None:
             seq.tokens = new_tokens
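The flag's ultimate effect is at the tokenizer: `detokenize_incrementally` wraps the HuggingFace tokenizer, and `skip_special_tokens` decides whether tokens such as EOS appear in the decoded text. A minimal illustration with plain `transformers` (this shows what the flag controls, not vLLM's actual helper):

    # Illustration of what skip_special_tokens controls at the tokenizer
    # level; plain HuggingFace decode, not vLLM's detokenize_incrementally.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    token_ids = tokenizer.encode("Hello world") + [tokenizer.eos_token_id]

    print(tokenizer.decode(token_ids, skip_special_tokens=True))
    # -> "Hello world"                  (the EOS token is dropped)
    print(tokenizer.decode(token_ids, skip_special_tokens=False))
    # -> "Hello world<|endoftext|>"     (the EOS token is kept)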
vllm/entrypoints/openai/api_server.py

@@ -225,6 +225,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             top_k=request.top_k,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

@@ -426,6 +427,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
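With the server-side plumbing above, a client of the OpenAI-compatible server can pass the flag per request. A hedged example using the `requests` library (host, port, and model name are placeholder assumptions):

    # Hypothetical request against a locally running vLLM OpenAI-compatible
    # server; URL and model name are placeholders.
    import requests

    response = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "facebook/opt-125m",
            "prompt": "Hello, my name is",
            "max_tokens": 32,
            # Keep special tokens (e.g. EOS) in the returned text.
            "skip_special_tokens": False,
        },
    )
    print(response.json()["choices"][0]["text"])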
vllm/entrypoints/openai/protocol.py

@@ -71,6 +71,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True


 class CompletionRequest(BaseModel):

@@ -96,6 +97,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True


 class LogProbs(BaseModel):
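Because both request models default the field to `True`, existing clients that never send `skip_special_tokens` keep the old behavior. A quick standalone sketch of the pydantic semantics (a hypothetical model mirroring the fields above):

    # Standalone sketch mirroring the pydantic fields above, to show
    # the default; CompletionRequestSketch is a made-up name.
    from typing import Optional
    from pydantic import BaseModel


    class CompletionRequestSketch(BaseModel):
        prompt: str
        skip_special_tokens: Optional[bool] = True


    print(CompletionRequestSketch(prompt="hi").skip_special_tokens)
    # -> True (field omitted, old behavior preserved)
    print(CompletionRequestSketch(
        prompt="hi", skip_special_tokens=False).skip_special_tokens)
    # -> False (client opts in to keeping special tokens)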
vllm/sampling_params.py

@@ -60,6 +60,8 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
+        skip_special_tokens: Whether to skip special tokens in the output.
+            Defaults to true.
     """

     def __init__(

@@ -79,6 +81,7 @@ class SamplingParams:
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        skip_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n

@@ -103,6 +106,7 @@ class SamplingParams:
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.skip_special_tokens = skip_special_tokens
         self._verify_args()
         if self.use_beam_search:

@@ -196,4 +200,5 @@ class SamplingParams:
                 f"stop={self.stop}, "
                 f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs}, "
+                f"skip_special_tokens={self.skip_special_tokens})")
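The updated `__repr__` makes the new field visible when sampling parameters are logged. A short round-trip sketch, assuming the constructor above:

    # Constructing SamplingParams with the new keyword; per the diff,
    # repr now includes skip_special_tokens.
    from vllm import SamplingParams

    params = SamplingParams(max_tokens=16, skip_special_tokens=False)
    assert params.skip_special_tokens is False
    print(params)
    # repr ends with: ... logprobs=None, skip_special_tokens=False)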