"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2d43094ffc9b1ee377651c6c8a358c81f0c96005"
Unverified commit 60597219, authored by Xiaoyu Zhang and committed by GitHub
Browse files

check user-specified model_max_len with hf derived max_model_len (#1778)

parent fc82f5a7
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import logging
import os
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Optional from typing import Optional
...@@ -20,6 +22,8 @@ from transformers import PretrainedConfig ...@@ -20,6 +22,8 @@ from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length from sglang.srt.hf_transformers_utils import get_config, get_context_length
logger = logging.getLogger(__name__)
class AttentionArch(IntEnum): class AttentionArch(IntEnum):
MLA = auto() MLA = auto()
...@@ -46,10 +50,29 @@ class ModelConfig: ...@@ -46,10 +50,29 @@ class ModelConfig:
model_override_args=model_override_args, model_override_args=model_override_args,
) )
self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_text_config = get_hf_text_config(self.hf_config)
derived_context_len = get_context_length(self.hf_text_config)
allow_long_context = os.environ.get(
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
)
if context_length is not None: if context_length is not None:
self.context_len = context_length if context_length > derived_context_len:
if allow_long_context:
logger.warning(
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
f"This may lead to incorrect model outputs or CUDA errors."
)
self.context_len = context_length
else:
raise ValueError(
f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
)
else:
self.context_len = context_length
else: else:
self.context_len = get_context_length(self.hf_text_config) self.context_len = derived_context_len
# Unify the config keys for hf_text_config # Unify the config keys for hf_text_config
self.head_dim = getattr( self.head_dim = getattr(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment