Unverified Commit 9b9a10d6 authored by sasha0552's avatar sasha0552 Committed by GitHub
Browse files

[Frontend] Dynamic RoPE scaling (#4638)

parent 99eff67b
...@@ -37,3 +37,57 @@ def test_get_sliding_window(): ...@@ -37,3 +37,57 @@ def test_get_sliding_window():
mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
def test_rope_scaling():
    """Verify that a caller-supplied ``rope_scaling`` lands in the HF config.

    Two models are checked, each first with its stock configuration and then
    with an explicit dynamic x2 override; the test asserts both the resulting
    ``hf_config.rope_scaling`` value and the derived ``max_model_len``.
    """
    dynamic_x2 = {"type": "dynamic", "factor": 2.0}
    longchat_stock_scaling = {"type": "linear", "factor": 8.0}

    def build_config(model_name, **overrides):
        # The same identifier is used for both the model and the tokenizer.
        return ModelConfig(
            model_name,
            model_name,
            tokenizer_mode="auto",
            trust_remote_code=False,
            dtype="float16",
            seed=0,
            **overrides,
        )

    def scaling_of(model_config):
        # The attribute may be absent from the HF config; treat that as None.
        return getattr(model_config.hf_config, "rope_scaling", None)

    llama = "meta-llama/Meta-Llama-3-8B-Instruct"
    longchat = "lmsys/longchat-13b-16k"

    # Llama 3 without an override: no scaling entry, 8192 context expected.
    stock_llama = build_config(llama)
    assert scaling_of(stock_llama) is None
    assert stock_llama.max_model_len == 8192

    # Llama 3 with the dynamic x2 override: 16384 context expected.
    scaled_llama = build_config(llama, rope_scaling=dynamic_x2)
    assert scaling_of(scaled_llama) == dynamic_x2
    assert scaled_llama.max_model_len == 16384

    # LongChat without an override: its own linear x8 scaling is expected.
    stock_longchat = build_config(longchat)
    assert scaling_of(stock_longchat) == longchat_stock_scaling
    assert stock_longchat.max_model_len == 16384

    # LongChat with the override: the stock scaling is replaced, so the
    # expected context drops to 4096.
    scaled_longchat = build_config(longchat, rope_scaling=dynamic_x2)
    assert scaling_of(scaled_longchat) == dynamic_x2
    assert scaled_longchat.max_model_len == 4096
...@@ -45,6 +45,9 @@ class ModelConfig: ...@@ -45,6 +45,9 @@ class ModelConfig:
code_revision: The specific revision to use for the model code on code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version. commit id. If unspecified, will use the default version.
rope_scaling: Dictionary containing the scaling configuration for the
RoPE embeddings. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
tokenizer_revision: The specific tokenizer version to use. It can be a tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use branch name, a tag name, or a commit id. If unspecified, will use
the default version. the default version.
...@@ -84,6 +87,7 @@ class ModelConfig: ...@@ -84,6 +87,7 @@ class ModelConfig:
seed: int, seed: int,
revision: Optional[str] = None, revision: Optional[str] = None,
code_revision: Optional[str] = None, code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
quantization: Optional[str] = None, quantization: Optional[str] = None,
...@@ -102,6 +106,7 @@ class ModelConfig: ...@@ -102,6 +106,7 @@ class ModelConfig:
self.seed = seed self.seed = seed
self.revision = revision self.revision = revision
self.code_revision = code_revision self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.tokenizer_revision = tokenizer_revision self.tokenizer_revision = tokenizer_revision
self.quantization = quantization self.quantization = quantization
self.quantization_param_path = quantization_param_path self.quantization_param_path = quantization_param_path
...@@ -116,7 +121,7 @@ class ModelConfig: ...@@ -116,7 +121,7 @@ class ModelConfig:
self.skip_tokenizer_init = skip_tokenizer_init self.skip_tokenizer_init = skip_tokenizer_init
self.hf_config = get_config(self.model, trust_remote_code, revision, self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision) code_revision, rope_scaling)
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config, self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
......
import argparse import argparse
import dataclasses import dataclasses
import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
...@@ -49,6 +50,7 @@ class EngineArgs: ...@@ -49,6 +50,7 @@ class EngineArgs:
disable_log_stats: bool = False disable_log_stats: bool = False
revision: Optional[str] = None revision: Optional[str] = None
code_revision: Optional[str] = None code_revision: Optional[str] = None
rope_scaling: Optional[dict] = None
tokenizer_revision: Optional[str] = None tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None quantization: Optional[str] = None
enforce_eager: bool = False enforce_eager: bool = False
...@@ -330,6 +332,11 @@ class EngineArgs: ...@@ -330,6 +332,11 @@ class EngineArgs:
'None, we assume the model weights are not ' 'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data ' 'quantized and use `dtype` to determine the data '
'type of the weights.') 'type of the weights.')
parser.add_argument('--rope-scaling',
default=None,
type=json.loads,
help='RoPE scaling configuration in JSON format. '
'For example, {"type":"dynamic","factor":2.0}')
parser.add_argument('--enforce-eager', parser.add_argument('--enforce-eager',
action='store_true', action='store_true',
help='Always use eager-mode PyTorch. If False, ' help='Always use eager-mode PyTorch. If False, '
...@@ -548,11 +555,12 @@ class EngineArgs: ...@@ -548,11 +555,12 @@ class EngineArgs:
model_config = ModelConfig( model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode, self.model, self.tokenizer, self.tokenizer_mode,
self.trust_remote_code, self.dtype, self.seed, self.revision, self.trust_remote_code, self.dtype, self.seed, self.revision,
self.code_revision, self.tokenizer_revision, self.max_model_len, self.code_revision, self.rope_scaling, self.tokenizer_revision,
self.quantization, self.quantization_param_path, self.max_model_len, self.quantization,
self.enforce_eager, self.max_context_len_to_capture, self.quantization_param_path, self.enforce_eager,
self.max_seq_len_to_capture, self.max_logprobs, self.max_context_len_to_capture, self.max_seq_len_to_capture,
self.skip_tokenizer_init, self.served_model_name) self.max_logprobs, self.skip_tokenizer_init,
self.served_model_name)
cache_config = CacheConfig(self.block_size, cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization, self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype, self.swap_space, self.kv_cache_dtype,
......
...@@ -104,10 +104,11 @@ class LLMEngine: ...@@ -104,10 +104,11 @@ class LLMEngine:
"Initializing an LLM engine (v%s) with config: " "Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, " "model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, " "rope_scaling=%r, tokenizer_revision=%s, "
"max_seq_len=%d, download_dir=%r, load_format=%s, " "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"tensor_parallel_size=%d, disable_custom_all_reduce=%s, " "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, " "disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, " "quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d, served_model_name=%s)", "decoding_config=%r, seed=%d, served_model_name=%s)",
vllm.__version__, vllm.__version__,
...@@ -117,6 +118,7 @@ class LLMEngine: ...@@ -117,6 +118,7 @@ class LLMEngine:
model_config.skip_tokenizer_init, model_config.skip_tokenizer_init,
model_config.tokenizer_mode, model_config.tokenizer_mode,
model_config.revision, model_config.revision,
model_config.rope_scaling,
model_config.tokenizer_revision, model_config.tokenizer_revision,
model_config.trust_remote_code, model_config.trust_remote_code,
model_config.dtype, model_config.dtype,
......
...@@ -2,9 +2,12 @@ from typing import Dict, Optional ...@@ -2,9 +2,12 @@ from typing import Dict, Optional
from transformers import AutoConfig, PretrainedConfig from transformers import AutoConfig, PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
JAISConfig, MPTConfig, RWConfig) JAISConfig, MPTConfig, RWConfig)
logger = init_logger(__name__)
_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
"chatglm": ChatGLMConfig, "chatglm": ChatGLMConfig,
"dbrx": DbrxConfig, "dbrx": DbrxConfig,
...@@ -18,7 +21,8 @@ _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { ...@@ -18,7 +21,8 @@ _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
def get_config(model: str, def get_config(model: str,
trust_remote_code: bool, trust_remote_code: bool,
revision: Optional[str] = None, revision: Optional[str] = None,
code_revision: Optional[str] = None) -> PretrainedConfig: code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None) -> PretrainedConfig:
try: try:
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model, model,
...@@ -41,6 +45,10 @@ def get_config(model: str, ...@@ -41,6 +45,10 @@ def get_config(model: str,
config = config_class.from_pretrained(model, config = config_class.from_pretrained(model,
revision=revision, revision=revision,
code_revision=code_revision) code_revision=code_revision)
if rope_scaling is not None:
logger.info("Updating rope_scaling from %r to %r",
getattr(config, "rope_scaling", None), rope_scaling)
config.update({"rope_scaling": rope_scaling})
return config return config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment