Unverified Commit e737446e authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`Modeling` / `Mixtral`] Fix GC + PEFT issues with Mixtral (#28061)

fix for Mixtral
parent 1e209317
......@@ -1016,6 +1016,13 @@ class MixtralModel(MixtralPreTrainedModel):
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
......@@ -1058,13 +1065,6 @@ class MixtralModel(MixtralPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment