Unverified Commit 146f6134 authored by Ran Chen's avatar Ran Chen Committed by GitHub
Browse files

Fix incorrect context length for llama3.2-11b (#1873)

parent 660ecb73
...@@ -88,19 +88,23 @@ CONTEXT_LENGTH_KEYS = [ ...@@ -88,19 +88,23 @@ CONTEXT_LENGTH_KEYS = [
def get_context_length(config):
    """Return a model's context window size from a huggingface model config.

    For multimodal LLMs the relevant length fields live on
    ``config.text_config``; text-only models keep them on ``config`` itself,
    so we fall back to the top-level config when no ``text_config`` exists.
    """
    cfg = getattr(config, "text_config", config)
    rope_scaling = getattr(cfg, "rope_scaling", None)

    # The scaling factor multiplies the stored max length only for plain
    # rope scaling. Configs that carry "original_max_position_embeddings"
    # (and llama3-style rope scaling) already report the post-scaling
    # length, so the factor must stay 1 for them.
    if not rope_scaling:
        factor = 1
    elif "original_max_position_embeddings" in rope_scaling:
        factor = 1
    elif rope_scaling.get("rope_type", None) == "llama3":
        factor = 1
    else:
        factor = rope_scaling.get("factor", 1)

    for key in CONTEXT_LENGTH_KEYS:
        length = getattr(cfg, key, None)
        if length is not None:
            return int(factor * length)

    # Conservative default when no known context-length field is present.
    return 2048
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment