Unverified Commit 8be6432b authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[CI Failure] Fix NomicBert max_model_len validation (#31662)


Signed-off-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
parent 43e3f8e4
...@@ -1264,12 +1264,6 @@ class VllmConfig: ...@@ -1264,12 +1264,6 @@ class VllmConfig:
computed_compile_ranges_split_points computed_compile_ranges_split_points
) )
def recalculate_max_model_len(self, max_model_len: int):
# Can only be called in try_verify_and_update_config
model_config = self.model_config
max_model_len = model_config.get_and_verify_max_len(max_model_len)
self.model_config.max_model_len = max_model_len
def try_verify_and_update_config(self): def try_verify_and_update_config(self):
if self.model_config is None: if self.model_config is None:
return return
......
...@@ -113,8 +113,8 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig): ...@@ -113,8 +113,8 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
class NomicBertModelConfig(VerifyAndUpdateConfig): class NomicBertModelConfig(VerifyAndUpdateConfig):
@staticmethod @staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None: def verify_and_update_model_config(model_config: "ModelConfig") -> None:
config = vllm_config.model_config.hf_config config = model_config.hf_config
assert config.__class__.__name__ == "NomicBertConfig" assert config.__class__.__name__ == "NomicBertConfig"
assert config.activation_function in ["swiglu", "gelu"] assert config.activation_function in ["swiglu", "gelu"]
...@@ -137,6 +137,10 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): ...@@ -137,6 +137,10 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
config.intermediate_size = config.n_inner config.intermediate_size = config.n_inner
config.hidden_size = config.n_embd config.hidden_size = config.n_embd
config.num_hidden_layers = config.n_layer config.num_hidden_layers = config.n_layer
model_config.model_arch_config.hidden_size = config.hidden_size
model_config.model_arch_config.total_num_hidden_layers = (
config.num_hidden_layers
)
head_dim = config.hidden_size // config.num_attention_heads head_dim = config.hidden_size // config.num_attention_heads
max_trained_positions = getattr(config, "max_trained_positions", 2048) max_trained_positions = getattr(config, "max_trained_positions", 2048)
...@@ -153,42 +157,43 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): ...@@ -153,42 +157,43 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
# The context extension uses vllm style rope_theta and rope_parameters. # The context extension uses vllm style rope_theta and rope_parameters.
# See #17785 #18755 # See #17785 #18755
if ( if (
not vllm_config.model_config.hf_overrides not model_config.hf_overrides
and vllm_config.model_config.original_max_model_len is None and model_config.original_max_model_len is None
): ):
# Default # Default
# Reset max_model_len to max_trained_positions. # Reset max_model_len to max_trained_positions.
# nomic-embed-text-v2-moe the length is set to 512 # nomic-embed-text-v2-moe the length is set to 512
# by sentence_bert_config.json. # by sentence_bert_config.json.
max_model_len_before = vllm_config.model_config.max_model_len max_model_len_before = model_config.max_model_len
max_model_len = min( max_model_len = min(model_config.max_model_len, max_trained_positions)
vllm_config.model_config.max_model_len, max_trained_positions
)
vllm_config.recalculate_max_model_len(max_model_len) model_config.max_model_len = model_config.get_and_verify_max_len(
logger.warning( max_model_len
"Nomic context extension is disabled. "
"Changing max_model_len from %s to %s. "
"To enable context extension, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
max_model_len_before,
vllm_config.model_config.max_model_len,
) )
if model_config.max_model_len != max_model_len_before:
logger.warning(
"Nomic context extension is disabled. "
"Changing max_model_len from %s to %s. "
"To enable context extension, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
max_model_len_before,
model_config.max_model_len,
)
else: else:
# We need to re-verify max_model_len to avoid lengths # We need to re-verify max_model_len to avoid lengths
# greater than position_embedding. # greater than position_embedding.
model_config = vllm_config.model_config
hf_text_config = model_config.hf_text_config hf_text_config = model_config.hf_text_config
if isinstance(model_config.hf_overrides, dict): if isinstance(model_config.hf_overrides, dict):
# hf_overrides_kw # hf_overrides_kw
max_model_len = model_config.hf_overrides.get( max_model_len = model_config.hf_overrides.get(
"max_model_len", vllm_config.model_config.max_model_len "max_model_len", model_config.max_model_len
) )
else: else:
# hf_overrides_fn # hf_overrides_fn
# This might be overridden by sentence_bert_config.json. # This might be overridden by sentence_bert_config.json.
max_model_len = vllm_config.model_config.max_model_len max_model_len = model_config.max_model_len
# reset hf_text_config for recalculate_max_model_len. # reset hf_text_config for recalculate_max_model_len.
if hasattr(hf_text_config, "max_model_len"): if hasattr(hf_text_config, "max_model_len"):
...@@ -196,13 +201,21 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): ...@@ -196,13 +201,21 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
hf_text_config.max_position_embeddings = max_trained_positions hf_text_config.max_position_embeddings = max_trained_positions
hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"] hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]
# Update the cached derived_max_model_len to enforce the limit
model_config.model_arch_config.derived_max_model_len_and_key = (
float(max_trained_positions),
"max_position_embeddings",
)
# The priority of sentence_bert_config.json is higher # The priority of sentence_bert_config.json is higher
# than max_position_embeddings # than max_position_embeddings
encoder_config = deepcopy(model_config.encoder_config) encoder_config = deepcopy(model_config.encoder_config)
encoder_config.pop("max_seq_length", None) encoder_config.pop("max_seq_length", None)
model_config.encoder_config = encoder_config model_config.encoder_config = encoder_config
vllm_config.recalculate_max_model_len(max_model_len) model_config.max_model_len = model_config.get_and_verify_max_len(
max_model_len
)
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment