Unverified commit a3e8611d, authored by Shanshan Shen, committed by GitHub
Browse files

[Bugfix] Limit the default value of `max_model_len` when it is not specified by users (#27556)


Signed-off-by: shen-shanshan <467638484@qq.com>
parent 7c2bdb83
...@@ -2112,20 +2112,13 @@ def _get_and_verify_max_len( ...@@ -2112,20 +2112,13 @@ def _get_and_verify_max_len(
if encoder_config and "max_seq_length" in encoder_config: if encoder_config and "max_seq_length" in encoder_config:
derived_max_model_len = encoder_config["max_seq_length"] derived_max_model_len = encoder_config["max_seq_length"]
# If the user specified a max length, make sure it is smaller than the # If the user didn't specify `max_model_len`, then use that derived from
# derived length from the HF model config. # the model config as a default value.
if max_model_len is None: if max_model_len is None:
max_model_len = int(derived_max_model_len) max_model_len = int(derived_max_model_len)
if current_platform.is_tpu(): max_model_len = current_platform.check_max_model_len(max_model_len)
logger.warning( # If the user specified a max length, make sure it is smaller than the
"--max-model-len is not specified, " # derived length from the HF model config.
"it's currently using model's default length %s, "
"which might be too large."
"Please input with --max-model-len based on your "
"request input length and output length, to avoid "
"unnecessary degradation.",
max_model_len,
)
elif max_model_len > derived_max_model_len: elif max_model_len > derived_max_model_len:
# Some models might have a separate key for specifying model_max_length # Some models might have a separate key for specifying model_max_length
# that will be bigger than derived_max_model_len. We compare user input # that will be bigger than derived_max_model_len. We compare user input
......
...@@ -608,6 +608,13 @@ class Platform: ...@@ -608,6 +608,13 @@ class Platform:
""" """
return None return None
@classmethod
def check_max_model_len(cls, max_model_len: int) -> int:
"""
Platform hook applied to a defaulted ``max_model_len``.

``_get_and_verify_max_len`` calls this only when the user did not
specify ``max_model_len`` and the value was derived from the model
config instead. The base implementation returns the value unchanged;
a platform subclass may override it to react to the defaulted length.

Args:
    max_model_len: The length derived from the model config.

Returns:
    The ``max_model_len`` that should be used as the default.
"""
return max_model_len
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED
......
...@@ -251,6 +251,22 @@ class TpuPlatform(Platform): ...@@ -251,6 +251,22 @@ class TpuPlatform(Platform):
def use_sync_weight_loader(cls) -> bool: def use_sync_weight_loader(cls) -> bool:
return True return True
@classmethod
def check_max_model_len(cls, max_model_len: int) -> int:
    """
    Warn that the model's default ``max_model_len`` is in use on TPU.

    This hook is invoked from ``_get_and_verify_max_len`` only when the
    user did not pass ``--max-model-len`` and the value was derived from
    the model config. The value itself is not modified; a warning is
    logged so the user can choose a smaller, workload-specific length.

    Args:
        max_model_len: The length derived from the model config.

    Returns:
        ``max_model_len`` unchanged.
    """
    # Adjacent string literals are concatenated verbatim, so every
    # fragment must carry its own trailing space; the original message
    # rendered as "too large.Please input".
    logger.warning(
        "--max-model-len is not specified, "
        "it's currently using model's default length %d, "
        "which might be too large. "
        "Please input with --max-model-len based on your "
        "request input length and output length, to avoid "
        "unnecessary degradation.",
        max_model_len,
    )
    return max_model_len
try: try:
from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment