Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9b9a10d6
Unverified
Commit
9b9a10d6
authored
May 22, 2024
by
sasha0552
Committed by
GitHub
May 22, 2024
Browse files
[Frontend] Dynamic RoPE scaling (#4638)
parent
99eff67b
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
89 additions
and
12 deletions
+89
-12
tests/test_config.py
tests/test_config.py
+55
-1
vllm/config.py
vllm/config.py
+6
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+13
-5
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+6
-4
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+9
-1
No files found.
tests/test_config.py
View file @
9b9a10d6
...
@@ -37,3 +37,57 @@ def test_get_sliding_window():
...
@@ -37,3 +37,57 @@ def test_get_sliding_window():
mistral_model_config
.
hf_config
.
sliding_window
=
TEST_SLIDING_WINDOW
mistral_model_config
.
hf_config
.
sliding_window
=
TEST_SLIDING_WINDOW
assert
mistral_model_config
.
get_sliding_window
()
==
TEST_SLIDING_WINDOW
assert
mistral_model_config
.
get_sliding_window
()
==
TEST_SLIDING_WINDOW
def
test_rope_scaling
():
TEST_ROPE_SCALING
=
{
"type"
:
"dynamic"
,
"factor"
:
2.0
}
LONGCHAT_ROPE_SCALING
=
{
"type"
:
"linear"
,
"factor"
:
8.0
}
llama_model_config
=
ModelConfig
(
"meta-llama/Meta-Llama-3-8B-Instruct"
,
"meta-llama/Meta-Llama-3-8B-Instruct"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
dtype
=
"float16"
,
seed
=
0
,
)
assert
getattr
(
llama_model_config
.
hf_config
,
"rope_scaling"
,
None
)
is
None
assert
llama_model_config
.
max_model_len
==
8192
llama_model_config
=
ModelConfig
(
"meta-llama/Meta-Llama-3-8B-Instruct"
,
"meta-llama/Meta-Llama-3-8B-Instruct"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
dtype
=
"float16"
,
seed
=
0
,
rope_scaling
=
TEST_ROPE_SCALING
,
)
assert
getattr
(
llama_model_config
.
hf_config
,
"rope_scaling"
,
None
)
==
TEST_ROPE_SCALING
assert
llama_model_config
.
max_model_len
==
16384
longchat_model_config
=
ModelConfig
(
"lmsys/longchat-13b-16k"
,
"lmsys/longchat-13b-16k"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
dtype
=
"float16"
,
seed
=
0
,
)
assert
getattr
(
longchat_model_config
.
hf_config
,
"rope_scaling"
,
None
)
==
LONGCHAT_ROPE_SCALING
assert
longchat_model_config
.
max_model_len
==
16384
longchat_model_config
=
ModelConfig
(
"lmsys/longchat-13b-16k"
,
"lmsys/longchat-13b-16k"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
dtype
=
"float16"
,
seed
=
0
,
rope_scaling
=
TEST_ROPE_SCALING
,
)
assert
getattr
(
longchat_model_config
.
hf_config
,
"rope_scaling"
,
None
)
==
TEST_ROPE_SCALING
assert
longchat_model_config
.
max_model_len
==
4096
vllm/config.py
View file @
9b9a10d6
...
@@ -45,6 +45,9 @@ class ModelConfig:
...
@@ -45,6 +45,9 @@ class ModelConfig:
code_revision: The specific revision to use for the model code on
code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a
Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
commit id. If unspecified, will use the default version.
rope_scaling: Dictionary containing the scaling configuration for the
RoPE embeddings. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
tokenizer_revision: The specific tokenizer version to use. It can be a
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
the default version.
...
@@ -84,6 +87,7 @@ class ModelConfig:
...
@@ -84,6 +87,7 @@ class ModelConfig:
seed
:
int
,
seed
:
int
,
revision
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
...
@@ -102,6 +106,7 @@ class ModelConfig:
...
@@ -102,6 +106,7 @@ class ModelConfig:
self
.
seed
=
seed
self
.
seed
=
seed
self
.
revision
=
revision
self
.
revision
=
revision
self
.
code_revision
=
code_revision
self
.
code_revision
=
code_revision
self
.
rope_scaling
=
rope_scaling
self
.
tokenizer_revision
=
tokenizer_revision
self
.
tokenizer_revision
=
tokenizer_revision
self
.
quantization
=
quantization
self
.
quantization
=
quantization
self
.
quantization_param_path
=
quantization_param_path
self
.
quantization_param_path
=
quantization_param_path
...
@@ -116,7 +121,7 @@ class ModelConfig:
...
@@ -116,7 +121,7 @@ class ModelConfig:
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
code_revision
)
code_revision
,
rope_scaling
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_text_config
,
dtype
)
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
hf_text_config
,
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
hf_text_config
,
...
...
vllm/engine/arg_utils.py
View file @
9b9a10d6
import
argparse
import
argparse
import
dataclasses
import
dataclasses
import
json
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
...
@@ -49,6 +50,7 @@ class EngineArgs:
...
@@ -49,6 +50,7 @@ class EngineArgs:
disable_log_stats
:
bool
=
False
disable_log_stats
:
bool
=
False
revision
:
Optional
[
str
]
=
None
revision
:
Optional
[
str
]
=
None
code_revision
:
Optional
[
str
]
=
None
code_revision
:
Optional
[
str
]
=
None
rope_scaling
:
Optional
[
dict
]
=
None
tokenizer_revision
:
Optional
[
str
]
=
None
tokenizer_revision
:
Optional
[
str
]
=
None
quantization
:
Optional
[
str
]
=
None
quantization
:
Optional
[
str
]
=
None
enforce_eager
:
bool
=
False
enforce_eager
:
bool
=
False
...
@@ -330,6 +332,11 @@ class EngineArgs:
...
@@ -330,6 +332,11 @@ class EngineArgs:
'None, we assume the model weights are not '
'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data '
'quantized and use `dtype` to determine the data '
'type of the weights.'
)
'type of the weights.'
)
parser
.
add_argument
(
'--rope-scaling'
,
default
=
None
,
type
=
json
.
loads
,
help
=
'RoPE scaling configuration in JSON format. '
'For example, {"type":"dynamic","factor":2.0}'
)
parser
.
add_argument
(
'--enforce-eager'
,
parser
.
add_argument
(
'--enforce-eager'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'Always use eager-mode PyTorch. If False, '
help
=
'Always use eager-mode PyTorch. If False, '
...
@@ -548,11 +555,12 @@ class EngineArgs:
...
@@ -548,11 +555,12 @@ class EngineArgs:
model_config
=
ModelConfig
(
model_config
=
ModelConfig
(
self
.
model
,
self
.
tokenizer
,
self
.
tokenizer_mode
,
self
.
model
,
self
.
tokenizer
,
self
.
tokenizer_mode
,
self
.
trust_remote_code
,
self
.
dtype
,
self
.
seed
,
self
.
revision
,
self
.
trust_remote_code
,
self
.
dtype
,
self
.
seed
,
self
.
revision
,
self
.
code_revision
,
self
.
tokenizer_revision
,
self
.
max_model_len
,
self
.
code_revision
,
self
.
rope_scaling
,
self
.
tokenizer_revision
,
self
.
quantization
,
self
.
quantization_param_path
,
self
.
max_model_len
,
self
.
quantization
,
self
.
enforce_eager
,
self
.
max_context_len_to_capture
,
self
.
quantization_param_path
,
self
.
enforce_eager
,
self
.
max_seq_len_to_capture
,
self
.
max_logprobs
,
self
.
max_context_len_to_capture
,
self
.
max_seq_len_to_capture
,
self
.
skip_tokenizer_init
,
self
.
served_model_name
)
self
.
max_logprobs
,
self
.
skip_tokenizer_init
,
self
.
served_model_name
)
cache_config
=
CacheConfig
(
self
.
block_size
,
cache_config
=
CacheConfig
(
self
.
block_size
,
self
.
gpu_memory_utilization
,
self
.
gpu_memory_utilization
,
self
.
swap_space
,
self
.
kv_cache_dtype
,
self
.
swap_space
,
self
.
kv_cache_dtype
,
...
...
vllm/engine/llm_engine.py
View file @
9b9a10d6
...
@@ -104,10 +104,11 @@ class LLMEngine:
...
@@ -104,10 +104,11 @@ class LLMEngine:
"Initializing an LLM engine (v%s) with config: "
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
"rope_scaling=%r, tokenizer_revision=%s, "
"max_seq_len=%d, download_dir=%r, load_format=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
"disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d, served_model_name=%s)"
,
"decoding_config=%r, seed=%d, served_model_name=%s)"
,
vllm
.
__version__
,
vllm
.
__version__
,
...
@@ -117,6 +118,7 @@ class LLMEngine:
...
@@ -117,6 +118,7 @@ class LLMEngine:
model_config
.
skip_tokenizer_init
,
model_config
.
skip_tokenizer_init
,
model_config
.
tokenizer_mode
,
model_config
.
tokenizer_mode
,
model_config
.
revision
,
model_config
.
revision
,
model_config
.
rope_scaling
,
model_config
.
tokenizer_revision
,
model_config
.
tokenizer_revision
,
model_config
.
trust_remote_code
,
model_config
.
trust_remote_code
,
model_config
.
dtype
,
model_config
.
dtype
,
...
...
vllm/transformers_utils/config.py
View file @
9b9a10d6
...
@@ -2,9 +2,12 @@ from typing import Dict, Optional
...
@@ -2,9 +2,12 @@ from typing import Dict, Optional
from
transformers
import
AutoConfig
,
PretrainedConfig
from
transformers
import
AutoConfig
,
PretrainedConfig
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
from
vllm.transformers_utils.configs
import
(
ChatGLMConfig
,
DbrxConfig
,
JAISConfig
,
MPTConfig
,
RWConfig
)
JAISConfig
,
MPTConfig
,
RWConfig
)
logger
=
init_logger
(
__name__
)
_CONFIG_REGISTRY
:
Dict
[
str
,
PretrainedConfig
]
=
{
_CONFIG_REGISTRY
:
Dict
[
str
,
PretrainedConfig
]
=
{
"chatglm"
:
ChatGLMConfig
,
"chatglm"
:
ChatGLMConfig
,
"dbrx"
:
DbrxConfig
,
"dbrx"
:
DbrxConfig
,
...
@@ -18,7 +21,8 @@ _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
...
@@ -18,7 +21,8 @@ _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
def
get_config
(
model
:
str
,
def
get_config
(
model
:
str
,
trust_remote_code
:
bool
,
trust_remote_code
:
bool
,
revision
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
)
->
PretrainedConfig
:
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
)
->
PretrainedConfig
:
try
:
try
:
config
=
AutoConfig
.
from_pretrained
(
config
=
AutoConfig
.
from_pretrained
(
model
,
model
,
...
@@ -41,6 +45,10 @@ def get_config(model: str,
...
@@ -41,6 +45,10 @@ def get_config(model: str,
config
=
config_class
.
from_pretrained
(
model
,
config
=
config_class
.
from_pretrained
(
model
,
revision
=
revision
,
revision
=
revision
,
code_revision
=
code_revision
)
code_revision
=
code_revision
)
if
rope_scaling
is
not
None
:
logger
.
info
(
"Updating rope_scaling from %r to %r"
,
getattr
(
config
,
"rope_scaling"
,
None
),
rope_scaling
)
config
.
update
({
"rope_scaling"
:
rope_scaling
})
return
config
return
config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment