"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "98364ea74f35b16afdd6541b8dc6e205131b5e61"
Unverified Commit e0182f3b, authored by Joao Gante, committed by GitHub

RoPE: relaxed rope validation (#32182)

* relaxed rope check

* lets also accept rope_type=None, defaulting to the original implementation

* type and rope_type can coexist
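In practice, the relaxed validation means configs written with the legacy `rope_scaling` schema still load, and out-of-range values are reported as warnings rather than hard errors. A minimal sketch of the new behavior (assuming a `transformers` build that includes this commit; the values mirror the Llama test added below):

```python
from transformers import LlamaConfig

# Legacy schema ("type" instead of "rope_type") is handled gracefully for BC;
# after __init__, the config also carries "rope_type".
legacy = LlamaConfig(rope_scaling={"type": "linear", "factor": 10.0})
print(legacy.rope_scaling["rope_type"])  # "linear"

# Out-of-range values ("factor" should be a float >= 1.0) now emit a warning on
# the "transformers.modeling_rope_utils" logger instead of raising, so older
# checkpoints keep loading.
relaxed = LlamaConfig(rope_scaling={"rope_type": "linear", "factor": -999.0})

# Missing mandatory keys for a given rope type still raise.
try:
    LlamaConfig(rope_scaling={"rope_type": "linear"})  # missing "factor"
except KeyError as exc:
    print(exc)
```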
src/transformers/modeling_rope_utils.py

```diff
@@ -354,6 +354,11 @@ ROPE_INIT_FUNCTIONS = {
 def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
     """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
+    # BC: "rope_type" was originally "type" -- let's gracefully handle it
+    if "rope_type" not in received_keys and "type" in received_keys:
+        received_keys -= {"type"}
+        received_keys.add("rope_type")
+
     missing_keys = required_keys - received_keys
     if missing_keys:
         raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
@@ -361,14 +366,14 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
     if optional_keys is not None:
         unused_keys = received_keys - required_keys - optional_keys
     else:
-        unused_keys = received_keys - received_keys
+        unused_keys = received_keys - required_keys
     if unused_keys:
-        raise KeyError(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
+        logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")


 def _validate_default_rope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys)
@@ -376,19 +381,19 @@ def _validate_default_rope_parameters(config: PretrainedConfig):

 def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


 def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"original_max_position_embeddings"}
@@ -397,12 +402,12 @@ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


 def _validate_yarn_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
     received_keys = set(rope_scaling.keys())
@@ -410,22 +415,22 @@ def _validate_yarn_parameters(config: PretrainedConfig):
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

     attention_factor = rope_scaling.get("attention_factor")
     if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
-        raise ValueError(
+        logger.warning(
             f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
         )
     beta_fast = rope_scaling.get("beta_fast")
     if beta_fast is not None and not isinstance(beta_fast, float):
-        raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
+        logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
     beta_slow = rope_scaling.get("beta_slow")
     if beta_slow is not None and not isinstance(beta_slow, float):
-        raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
+        logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")

     if (beta_fast or 32) < (beta_slow or 1):
-        raise ValueError(
+        logger.warning(
             f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
             f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
         )
@@ -433,7 +438,7 @@ def _validate_yarn_parameters(config: PretrainedConfig):

 def _validate_longrope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "short_factor", "long_factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
@@ -445,15 +450,15 @@ def _validate_longrope_parameters(config: PretrainedConfig):
     short_factor = rope_scaling.get("short_factor")
     if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
-        raise ValueError(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
+        logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
     if not len(short_factor) == dim // 2:
-        raise ValueError(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
+        logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")

     long_factor = rope_scaling.get("long_factor")
     if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
-        raise ValueError(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
+        logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
     if not len(long_factor) == dim // 2:
-        raise ValueError(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
+        logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")

     # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
     # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
@@ -468,48 +473,48 @@ def _validate_longrope_parameters(config: PretrainedConfig):
     else:
         factor = rope_scaling.get("factor")
         if factor is None:
-            raise ValueError("Missing required keys in `rope_scaling`: 'factor'")
+            logger.warning("Missing required keys in `rope_scaling`: 'factor'")
         elif not isinstance(factor, float) or factor < 1.0:
-            raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+            logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

         attention_factor = rope_scaling.get("attention_factor")
         if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
-            raise ValueError(
+            logger.warning(
                 f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
             )


 def _validate_llama3_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

     low_freq_factor = rope_scaling["low_freq_factor"]
     high_freq_factor = rope_scaling["high_freq_factor"]
     if low_freq_factor is None or not isinstance(low_freq_factor, float):
-        raise ValueError(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+        logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
     if high_freq_factor is None or not isinstance(high_freq_factor, float):
-        raise ValueError(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+        logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
     if high_freq_factor < low_freq_factor:
-        raise ValueError(
+        logger.warning(
             "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
             f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
         )

     original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
     if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
-        raise ValueError(
+        logger.warning(
             "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
             f"{original_max_position_embeddings}"
         )
     if original_max_position_embeddings >= config.max_position_embeddings:
-        raise ValueError(
+        logger.warning(
             "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
             f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
         )
@@ -534,17 +539,12 @@ def rope_config_validation(config: PretrainedConfig):
     if rope_scaling is None:
         return

-    possible_rope_types = set(ROPE_INIT_FUNCTIONS.keys())
-    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
-    if rope_type is None:
-        raise ValueError(
-            f"rope_scaling must contain a non-None 'rope_type' field. Possible options are {possible_rope_types}"
-        )
+    # BC: "rope_type" was originally "type"
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))

     validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
     if validation_fn is not None:
         validation_fn(config)
     else:
-        raise ValueError(
+        logger.warning(
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
         )
```
src/transformers/models/llama/configuration_llama.py

```diff
@@ -189,6 +189,9 @@ class LlamaConfig(PretrainedConfig):
         self.mlp_bias = mlp_bias

         # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, move it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         rope_config_validation(self)

         super().__init__(
```
tests/models/llama/test_modeling_llama.py

```diff
@@ -526,6 +526,60 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         torch.testing.assert_close(old_cos_long, new_cos_long)
         torch.testing.assert_close(old_sin_long, new_sin_long)

+    def test_model_loading_old_rope_configs(self):
+        def _reinitialize_config(base_config, new_kwargs):
+            # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
+            # steps.
+            base_config_dict = base_config.to_dict()
+            new_config = LlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs})
+            return new_config
+
+        # from untouched config -> ✅
+        base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+        original_model = LlamaForCausalLM(base_config).to(torch_device)
+        original_model(**model_inputs)
+
+        # from a config with the expected rope configuration -> ✅
+        config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
+        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model(**model_inputs)
+
+        # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
+        config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
+        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model(**model_inputs)
+
+        # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
+        config = _reinitialize_config(
+            base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}}
+        )
+        self.assertTrue(config.rope_scaling["type"] == "linear")
+        self.assertTrue(config.rope_scaling["rope_type"] == "linear")
+        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model(**model_inputs)
+
+        # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
+        with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
+            config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
+            original_model = LlamaForCausalLM(config).to(torch_device)
+            original_model(**model_inputs)
+        self.assertEqual(len(logs.output), 1)
+        self.assertIn("factor field", logs.output[0])
+
+        # from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning
+        with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
+            config = _reinitialize_config(
+                base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
+            )
+            original_model = LlamaForCausalLM(config).to(torch_device)
+            original_model(**model_inputs)
+        self.assertEqual(len(logs.output), 1)
+        self.assertIn("Unrecognized keys", logs.output[0])
+
+        # from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception
+        with self.assertRaises(KeyError):
+            config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}})  # missing "factor"
+
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
```