Unverified Commit 4ee4826e authored by 汪志鹏's avatar 汪志鹏 Committed by GitHub
Browse files

[BugFix] Correct max_model_len derivation from config.json for Mistral format (#17937)


Signed-off-by: default avatar汪志鹏 <wangzhipeng628@gmail.com>
Co-authored-by: default avatartracelogfb <48808670+tracelogfb@users.noreply.github.com>
Co-authored-by: default avatarStephen Chen <tracelog@meta.com>
parent 60017dc8
......@@ -686,9 +686,24 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
config_dict["hidden_act"] = config_dict.get("activation", "silu")
config_dict["tie_word_embeddings"] = config_dict.get(
"tie_embeddings", False)
config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000)
config_dict["max_position_embeddings"] = config_dict.get(
"max_position_embeddings", 128_000)
if config_dict.get("max_position_embeddings") is None:
max_position_embeddings = 128_000
try:
trust_remote_code_val = kwargs.get("trust_remote_code", False)
hf_config = get_config(model=model,
trust_remote_code=trust_remote_code_val,
revision=revision,
config_format=ConfigFormat.HF)
if hf_value := hf_config.get_text_config().max_position_embeddings:
max_position_embeddings = hf_value
except Exception as e:
logger.warning(
"The params.json file is missing 'max_position_embeddings'"
" and could not get a value from the HF config."
" Defaulting to 128000",
exc_info=e)
config_dict["max_position_embeddings"] = max_position_embeddings
if config_dict.get("quantization") is not None:
quantization = config_dict.get("quantization", {})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment