Unverified Commit beb88994 authored by Andy Lo's avatar Andy Lo Committed by GitHub
Browse files

Fix mistral sliding window parsing (#33521)


Signed-off-by: default avatarAndy Lo <andy@mistral.ai>
parent ce88756b
...@@ -225,19 +225,6 @@ class MistralConfigParser(ConfigParserBase): ...@@ -225,19 +225,6 @@ class MistralConfigParser(ConfigParserBase):
config = adapt_config_dict(config_dict, defaults=hf_config_dict) config = adapt_config_dict(config_dict, defaults=hf_config_dict)
# Mistral configs may define sliding_window as list[int]. Convert it
# to int and add the layer_types list[str] to make it HF compatible
if (sliding_window := getattr(config, "sliding_window", None)) and isinstance(
sliding_window, list
):
pattern_repeats = config.num_hidden_layers // len(sliding_window)
layer_types = sliding_window * pattern_repeats
config.layer_types = [
"full_attention" if layer_type is None else "sliding_attention"
for layer_type in layer_types
]
config.sliding_window = next(filter(None, sliding_window), None)
return config_dict, config return config_dict, config
......
...@@ -14,6 +14,7 @@ def adapt_config_dict( ...@@ -14,6 +14,7 @@ def adapt_config_dict(
defaults: dict[str, Any], defaults: dict[str, Any],
) -> PretrainedConfig: ) -> PretrainedConfig:
config_dict = _remap_general_mistral_args(config_dict) config_dict = _remap_general_mistral_args(config_dict)
config_dict = _remap_mistral_sliding_window(config_dict)
if bool(config_dict.get("quantization")): if bool(config_dict.get("quantization")):
config_dict = _remap_mistral_quantization_args(config_dict) config_dict = _remap_mistral_quantization_args(config_dict)
...@@ -161,6 +162,29 @@ def _remap_general_mistral_args(config: dict) -> dict: ...@@ -161,6 +162,29 @@ def _remap_general_mistral_args(config: dict) -> dict:
return config return config
def _remap_mistral_sliding_window(config: dict) -> dict:
# Remap sliding_window (list) -> layer_types (list) + sliding window (int)
# for HF compatibility
# Mistral configs may define sliding_window as list[int]. Convert it
# to int and add the layer_types list[str] to make it HF compatible
if sliding_window := config.get("sliding_window"):
if isinstance(sliding_window, list):
pattern_repeats = config["num_hidden_layers"] // len(sliding_window)
layer_types = sliding_window * pattern_repeats
config["layer_types"] = [
"full_attention" if layer_type is None else "sliding_attention"
for layer_type in layer_types
]
assert len(set(sliding_window) - {None}) <= 1, sliding_window
config["sliding_window"] = next(filter(None, sliding_window), None)
elif isinstance(sliding_window, int) and config.get("layer_types") is None:
config["layer_types"] = ["sliding_attention"] * config["num_hidden_layers"]
else:
raise ValueError(f"Unsupported sliding_window type: {sliding_window}")
return config
def _remap_mistral_quantization_args(config: dict) -> dict: def _remap_mistral_quantization_args(config: dict) -> dict:
if config.get("quantization"): if config.get("quantization"):
quantization = config.pop("quantization", {}) quantization = config.pop("quantization", {})
...@@ -195,14 +219,6 @@ def _remap_mistral_audio_args(config: dict) -> dict: ...@@ -195,14 +219,6 @@ def _remap_mistral_audio_args(config: dict) -> dict:
else: else:
block_pool_size = 1 block_pool_size = 1
_maybe_sliding_window = encoder_args.get("ragged_attention", None)
if _maybe_sliding_window is None:
sliding_window = None
elif _maybe_sliding_window.isdigit():
sliding_window = int(_maybe_sliding_window)
else:
raise NotImplementedError(f"Unsupported: {_maybe_sliding_window=}")
architecture = ( architecture = (
"VoxtralRealtimeGeneration" "VoxtralRealtimeGeneration"
if encoder_args.get("causal") if encoder_args.get("causal")
...@@ -229,7 +245,7 @@ def _remap_mistral_audio_args(config: dict) -> dict: ...@@ -229,7 +245,7 @@ def _remap_mistral_audio_args(config: dict) -> dict:
max_source_positions=encoder_args["max_source_positions"], max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default is_encoder_decoder=False, # Override WhisperConfig default
is_causal=encoder_args.get("causal", False), is_causal=encoder_args.get("causal", False),
sliding_window=sliding_window, sliding_window=encoder_args.get("sliding_window", None),
block_pool_size=block_pool_size, block_pool_size=block_pool_size,
pos_embed=encoder_args.get("pos_embed", "sinusoidal"), pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
# only needed for RoPE # only needed for RoPE
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment