vllm · Commit 58a072be (unverified)
Authored Jul 25, 2023 by Zhuohan Li; committed by GitHub on Jul 25, 2023
Parent: 82ad323d

[Fix] Add model sequence length into model config (#575)
Showing 3 changed files with 27 additions and 18 deletions (+27 / -18):

vllm/config.py (+20 / -0)
vllm/engine/arg_utils.py (+2 / -3)
vllm/entrypoints/openai/api_server.py (+5 / -15)
vllm/config.py

@@ -109,6 +109,26 @@ class ModelConfig:
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
 
+    def get_max_model_len(self) -> int:
+        max_model_len = float("inf")
+        possible_keys = [
+            # OPT
+            "max_position_embeddings",
+            # GPT-2
+            "n_positions",
+            # MPT
+            "max_seq_len",
+            # Others
+            "max_sequence_length",
+            "max_seq_length",
+            "seq_len",
+        ]
+        for key in possible_keys:
+            max_len_key = getattr(self.hf_config, key, None)
+            if max_len_key is not None:
+                max_model_len = min(max_model_len, max_len_key)
+        return max_model_len
+
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size
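For illustration, a minimal sketch of how the new ModelConfig.get_max_model_len helper resolves a model's context length: it scans the candidate attributes and keeps the smallest value it finds. The SimpleNamespace object below is a hypothetical stand-in for a Hugging Face config and is not part of the commit.

from types import SimpleNamespace

# Hypothetical stand-in for self.hf_config; real configs come from transformers.
hf_config = SimpleNamespace(max_position_embeddings=2048, seq_len=1024)

max_model_len = float("inf")
for key in ("max_position_embeddings", "n_positions", "max_seq_len",
            "max_sequence_length", "max_seq_length", "seq_len"):
    value = getattr(hf_config, key, None)
    if value is not None:
        max_model_len = min(max_model_len, value)

print(max_model_len)  # 1024 -- the smallest declared limit wins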
vllm/engine/arg_utils.py

@@ -155,10 +155,9 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
-        max_model_len = getattr(model_config.hf_config,
-                                'max_position_embeddings', float('inf'))
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_model_len)
+                                           self.max_num_seqs,
+                                           model_config.get_max_model_len())
         return model_config, cache_config, parallel_config, scheduler_config
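A hedged illustration of why the previous single-key lookup could miss a model's limit: a config that exposes only max_seq_len (as MPT-style models do) falls through to infinity under the old getattr call, whereas the new ModelConfig.get_max_model_len() scans several keys. The SimpleNamespace object is illustrative only, not part of the commit.

from types import SimpleNamespace

# Hypothetical MPT-like config without max_position_embeddings.
mpt_like = SimpleNamespace(max_seq_len=8192)

# Old lookup in arg_utils.py consulted a single attribute:
old_len = getattr(mpt_like, "max_position_embeddings", float("inf"))
print(old_len)  # inf -- the scheduler would see no length limit

# The new path asks ModelConfig.get_max_model_len(), which also checks
# max_seq_len and therefore returns 8192 for this config.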
vllm/entrypoints/openai/api_server.py

@@ -107,25 +107,14 @@ async def get_gen_prompt(request) -> str:
     return prompt
 
-async def check_length(request, prompt, model_config):
-    if hasattr(model_config.hf_config, "max_sequence_length"):
-        context_len = model_config.hf_config.max_sequence_length
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    elif hasattr(model_config.hf_config, "max_position_embeddings"):
-        context_len = model_config.hf_config.max_position_embeddings
-    elif hasattr(model_config.hf_config, "seq_length"):
-        context_len = model_config.hf_config.seq_length
-    else:
-        context_len = 2048
-
+async def check_length(request, prompt):
     input_ids = tokenizer(prompt).input_ids
     token_num = len(input_ids)
 
-    if token_num + request.max_tokens > context_len:
+    if token_num + request.max_tokens > max_model_len:
         return create_error_response(
             HTTPStatus.BAD_REQUEST,
-            f"This model's maximum context length is {context_len} tokens. "
+            f"This model's maximum context length is {max_model_len} tokens. "
             f"However, you requested {request.max_tokens + token_num} tokens "
             f"({token_num} in the messages, "
             f"{request.max_tokens} in the completion). "

@@ -194,7 +183,7 @@ async def create_chat_completion(raw_request: Request):
             "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt, engine_model_config)
+    error_check_ret = await check_length(request, prompt)
     if error_check_ret is not None:
         return error_check_ret

@@ -591,6 +580,7 @@ if __name__ == "__main__":
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
+    max_model_len = engine_model_config.get_max_model_len()
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(engine_args.tokenizer,
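To make the new control flow in api_server.py easier to follow, here is a simplified, self-contained sketch: max_model_len is computed once at startup from the engine's model config and then read as a module-level value inside check_length. The whitespace tokenizer, the hard-coded limit, and the synchronous function below are illustrative stand-ins, not the server's actual components.

from http import HTTPStatus

# In the real server this comes from engine_model_config.get_max_model_len().
max_model_len = 4096

def tokenize(prompt):
    # Stand-in for the Hugging Face tokenizer used by the server.
    return prompt.split()

def check_length(max_tokens, prompt):
    token_num = len(tokenize(prompt))
    if token_num + max_tokens > max_model_len:
        return (HTTPStatus.BAD_REQUEST,
                f"This model's maximum context length is {max_model_len} tokens. "
                f"However, you requested {max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion).")
    return None

print(check_length(16, "hello world"))  # None -- request fits in the context window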