Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1d7c940d
"vscode:/vscode.git/clone" did not exist on "3c4cebf751a6d2ff9ada2f8234bab17ba7283e09"
Unverified
Commit
1d7c940d
authored
Apr 05, 2024
by
Thomas Parnell
Committed by
GitHub
Apr 05, 2024
Browse files
Add option to completion API to truncate prompt tokens (#3144)
parent
cfaf49a1
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
8 deletions
+41
-8
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+3
-1
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+8
-2
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+18
-4
vllm/sampling_params.py
vllm/sampling_params.py
+12
-1
No files found.
vllm/entrypoints/openai/protocol.py
View file @
1d7c940d
...
@@ -4,7 +4,7 @@ import time
...
@@ -4,7 +4,7 @@ import time
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Union
import
torch
import
torch
from
pydantic
import
BaseModel
,
Field
,
model_validator
from
pydantic
import
BaseModel
,
Field
,
conint
,
model_validator
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
...
@@ -229,6 +229,7 @@ class CompletionRequest(BaseModel):
...
@@ -229,6 +229,7 @@ class CompletionRequest(BaseModel):
min_tokens
:
Optional
[
int
]
=
0
min_tokens
:
Optional
[
int
]
=
0
skip_special_tokens
:
Optional
[
bool
]
=
True
skip_special_tokens
:
Optional
[
bool
]
=
True
spaces_between_special_tokens
:
Optional
[
bool
]
=
True
spaces_between_special_tokens
:
Optional
[
bool
]
=
True
truncate_prompt_tokens
:
Optional
[
conint
(
ge
=
1
)]
=
None
# doc: end-completion-sampling-params
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
# doc: begin-completion-extra-params
...
@@ -309,6 +310,7 @@ class CompletionRequest(BaseModel):
...
@@ -309,6 +310,7 @@ class CompletionRequest(BaseModel):
include_stop_str_in_output
=
self
.
include_stop_str_in_output
,
include_stop_str_in_output
=
self
.
include_stop_str_in_output
,
length_penalty
=
self
.
length_penalty
,
length_penalty
=
self
.
length_penalty
,
logits_processors
=
logits_processors
,
logits_processors
=
logits_processors
,
truncate_prompt_tokens
=
self
.
truncate_prompt_tokens
,
)
)
@
model_validator
(
mode
=
"before"
)
@
model_validator
(
mode
=
"before"
)
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
1d7c940d
...
@@ -137,10 +137,16 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -137,10 +137,16 @@ class OpenAIServingCompletion(OpenAIServing):
for
i
,
prompt
in
enumerate
(
prompts
):
for
i
,
prompt
in
enumerate
(
prompts
):
if
prompt_is_tokens
:
if
prompt_is_tokens
:
input_ids
=
self
.
_validate_prompt_and_tokenize
(
input_ids
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt_ids
=
prompt
)
request
,
prompt_ids
=
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
)
else
:
else
:
input_ids
=
self
.
_validate_prompt_and_tokenize
(
input_ids
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
prompt
)
request
,
prompt
=
prompt
,
truncate_prompt_tokens
=
sampling_params
.
truncate_prompt_tokens
)
generators
.
append
(
generators
.
append
(
self
.
engine
.
generate
(
prompt
,
self
.
engine
.
generate
(
prompt
,
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
1d7c940d
...
@@ -4,6 +4,8 @@ from dataclasses import dataclass
...
@@ -4,6 +4,8 @@ from dataclasses import dataclass
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
from
pydantic
import
conint
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
,
ErrorResponse
,
CompletionRequest
,
ErrorResponse
,
...
@@ -66,7 +68,8 @@ class OpenAIServing:
...
@@ -66,7 +68,8 @@ class OpenAIServing:
self
.
tokenizer
=
get_tokenizer
(
self
.
tokenizer
=
get_tokenizer
(
engine_model_config
.
tokenizer
,
engine_model_config
.
tokenizer
,
tokenizer_mode
=
engine_model_config
.
tokenizer_mode
,
tokenizer_mode
=
engine_model_config
.
tokenizer_mode
,
trust_remote_code
=
engine_model_config
.
trust_remote_code
)
trust_remote_code
=
engine_model_config
.
trust_remote_code
,
truncation_side
=
"left"
)
async
def
show_available_models
(
self
)
->
ModelList
:
async
def
show_available_models
(
self
)
->
ModelList
:
"""Show available models. Right now we only have one model."""
"""Show available models. Right now we only have one model."""
...
@@ -164,15 +167,26 @@ class OpenAIServing:
...
@@ -164,15 +167,26 @@ class OpenAIServing:
self
,
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
prompt
:
Optional
[
str
]
=
None
,
prompt
:
Optional
[
str
]
=
None
,
prompt_ids
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
int
]:
prompt_ids
:
Optional
[
List
[
int
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
conint
(
ge
=
1
)]
=
None
)
->
List
[
int
]:
if
not
(
prompt
or
prompt_ids
):
if
not
(
prompt
or
prompt_ids
):
raise
ValueError
(
"Either prompt or prompt_ids should be provided."
)
raise
ValueError
(
"Either prompt or prompt_ids should be provided."
)
if
(
prompt
and
prompt_ids
):
if
(
prompt
and
prompt_ids
):
raise
ValueError
(
raise
ValueError
(
"Only one of prompt or prompt_ids should be provided."
)
"Only one of prompt or prompt_ids should be provided."
)
input_ids
=
prompt_ids
if
prompt_ids
is
not
None
else
self
.
tokenizer
(
if
prompt_ids
is
None
:
prompt
).
input_ids
tokenizer_kwargs
=
{}
if
truncate_prompt_tokens
is
None
else
{
"truncation"
:
True
,
"max_length"
:
truncate_prompt_tokens
,
}
input_ids
=
self
.
tokenizer
(
prompt
,
**
tokenizer_kwargs
).
input_ids
elif
truncate_prompt_tokens
is
not
None
:
input_ids
=
prompt_ids
[
-
truncate_prompt_tokens
:]
else
:
input_ids
=
prompt_ids
token_num
=
len
(
input_ids
)
token_num
=
len
(
input_ids
)
if
request
.
max_tokens
is
None
:
if
request
.
max_tokens
is
None
:
...
...
vllm/sampling_params.py
View file @
1d7c940d
...
@@ -5,6 +5,7 @@ from functools import cached_property
...
@@ -5,6 +5,7 @@ from functools import cached_property
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Union
import
torch
import
torch
from
pydantic
import
conint
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
...
@@ -94,6 +95,9 @@ class SamplingParams:
...
@@ -94,6 +95,9 @@ class SamplingParams:
tokens in the output. Defaults to True.
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
logits_processors: List of functions that modify logits based on
previously generated tokens.
previously generated tokens.
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
"""
"""
def
__init__
(
def
__init__
(
...
@@ -123,6 +127,7 @@ class SamplingParams:
...
@@ -123,6 +127,7 @@ class SamplingParams:
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
conint
(
ge
=
1
)]
=
None
,
)
->
None
:
)
->
None
:
self
.
n
=
n
self
.
n
=
n
self
.
best_of
=
best_of
if
best_of
is
not
None
else
n
self
.
best_of
=
best_of
if
best_of
is
not
None
else
n
...
@@ -160,6 +165,7 @@ class SamplingParams:
...
@@ -160,6 +165,7 @@ class SamplingParams:
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
logits_processors
=
logits_processors
self
.
logits_processors
=
logits_processors
self
.
include_stop_str_in_output
=
include_stop_str_in_output
self
.
include_stop_str_in_output
=
include_stop_str_in_output
self
.
truncate_prompt_tokens
=
truncate_prompt_tokens
self
.
_verify_args
()
self
.
_verify_args
()
if
self
.
use_beam_search
:
if
self
.
use_beam_search
:
self
.
_verify_beam_search
()
self
.
_verify_beam_search
()
...
@@ -216,6 +222,10 @@ class SamplingParams:
...
@@ -216,6 +222,10 @@ class SamplingParams:
if
self
.
prompt_logprobs
is
not
None
and
self
.
prompt_logprobs
<
0
:
if
self
.
prompt_logprobs
is
not
None
and
self
.
prompt_logprobs
<
0
:
raise
ValueError
(
f
"prompt_logprobs must be non-negative, got "
raise
ValueError
(
f
"prompt_logprobs must be non-negative, got "
f
"
{
self
.
prompt_logprobs
}
."
)
f
"
{
self
.
prompt_logprobs
}
."
)
if
(
self
.
truncate_prompt_tokens
is
not
None
and
self
.
truncate_prompt_tokens
<
1
):
raise
ValueError
(
f
"truncate_prompt_tokens must be >= 1, "
f
"got
{
self
.
truncate_prompt_tokens
}
"
)
if
self
.
stop
and
not
self
.
detokenize
:
if
self
.
stop
and
not
self
.
detokenize
:
raise
ValueError
(
raise
ValueError
(
"stop strings are only supported when detokenize is True. "
"stop strings are only supported when detokenize is True. "
...
@@ -300,4 +310,5 @@ class SamplingParams:
...
@@ -300,4 +310,5 @@ class SamplingParams:
f
"prompt_logprobs=
{
self
.
prompt_logprobs
}
, "
f
"prompt_logprobs=
{
self
.
prompt_logprobs
}
, "
f
"skip_special_tokens=
{
self
.
skip_special_tokens
}
, "
f
"skip_special_tokens=
{
self
.
skip_special_tokens
}
, "
"spaces_between_special_tokens="
"spaces_between_special_tokens="
f
"
{
self
.
spaces_between_special_tokens
}
)"
)
f
"
{
self
.
spaces_between_special_tokens
}
, "
f
"truncate_prompt_tokens=
{
self
.
truncate_prompt_tokens
}
)"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment