Unverified Commit 0bb1e885 authored by Antoni Baum, committed by GitHub

Make `max_model_len` configurable (#972)

parent d6545ad2
@@ -38,6 +38,8 @@ class ModelConfig:
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
""" """
def __init__( def __init__(
...@@ -50,6 +52,7 @@ class ModelConfig: ...@@ -50,6 +52,7 @@ class ModelConfig:
load_format: str, load_format: str,
dtype: str, dtype: str,
seed: int, seed: int,
max_model_len: Optional[int] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
@@ -63,6 +66,16 @@ class ModelConfig:
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self._verify_load_format()
        self._verify_tokenizer_mode()
self.max_model_len = None
if max_model_len is not None:
derived_max_model_len = self.get_max_model_len()
if max_model_len > derived_max_model_len:
logger.warning(
f"User-specified max_model_len ({max_model_len}) is "
f"greater than the derived max_model_len "
f"({derived_max_model_len}). Make sure the value is "
"correct and within the model context size.")
self.max_model_len = max_model_len

    def _verify_load_format(self) -> None:
        load_format = self.load_format.lower()
@@ -134,6 +147,8 @@ class ModelConfig:
        return total_num_attention_heads // parallel_config.tensor_parallel_size

    def get_max_model_len(self) -> int:
if self.max_model_len is not None:
return self.max_model_len
        max_model_len = float("inf")
        possible_keys = [
            # OPT
...
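The net effect of the config.py changes is an override-or-derive pattern: `get_max_model_len()` returns the user-supplied value when one was given, and otherwise scans the Hugging Face config for a context-length key. A minimal standalone sketch of that pattern (the two keys shown are an illustrative subset of the list vLLM actually checks, and `resolve_max_model_len` is a hypothetical helper, not part of this diff):

import logging
from typing import Optional

logger = logging.getLogger(__name__)

def resolve_max_model_len(hf_config: dict,
                          max_model_len: Optional[int] = None) -> int:
    # Derive a limit from whichever context-length key the model config has.
    # (Illustrative subset; the real list covers more architectures.)
    possible_keys = ["max_position_embeddings", "max_sequence_length"]
    derived = min((hf_config[key] for key in possible_keys
                   if key in hf_config), default=float("inf"))
    if max_model_len is None:
        return derived
    if max_model_len > derived:
        # As in the patch: warn about the mismatch, but keep the user's value.
        logger.warning(
            f"User-specified max_model_len ({max_model_len}) is greater "
            f"than the derived max_model_len ({derived}).")
    return max_model_len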
@@ -18,6 +18,7 @@ class EngineArgs:
    load_format: str = 'auto'
    dtype: str = 'auto'
    seed: int = 0
max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
@@ -89,6 +90,11 @@ class EngineArgs:
                            'The "auto" option will use FP16 precision '
                            'for FP32 and FP16 models, and BF16 precision '
                            'for BF16 models.')
parser.add_argument('--max-model-len',
type=int,
default=None,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
@@ -153,7 +159,7 @@ class EngineArgs:
        model_config = ModelConfig(self.model, self.tokenizer,
                                   self.tokenizer_mode, self.trust_remote_code,
                                   self.download_dir, self.load_format,
                                   self.dtype, self.seed, self.max_model_len)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space)
...
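With the flag threaded through `EngineArgs`, the limit can be set from either entry point. A hedged usage sketch (the model name is just an example; the `LLM` constructor forwards extra keyword arguments to `EngineArgs`, so `max_model_len` should pass straight through):

from vllm import LLM

# Cap sequences (prompt + output) at 512 tokens instead of the
# model-derived default.
llm = LLM(model="facebook/opt-125m", max_model_len=512)

On the command line, any server that builds its parser via `EngineArgs.add_cli_args` picks up the new flag, e.g. (server module path assumed from the vLLM tree of this era):

python -m vllm.entrypoints.api_server --model facebook/opt-125m --max-model-len 512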