Unverified commit 8d2534c4, authored by Arthur, committed by GitHub

let's not warn when someone is running a forward (#32176)

* let's not warn when someone is running a forward without cache + self.training

* more models

* fixup
parent e0182f3b
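Every hunk below applies the same guard: the backward-compatibility conversion of a legacy tuple `past_key_values` into a `DynamicCache` (and, in several models, the deprecation warning that comes with it) is now skipped while `self.training` is true, since a training forward pass never returns a usable cache anyway. A minimal standalone sketch of the pattern, assuming the `transformers` `Cache`/`DynamicCache` API; the free-standing helper is hypothetical, as in the repo this logic sits inline in each model's `forward`:

```python
from transformers.cache_utils import Cache, DynamicCache

def convert_legacy_cache_unless_training(past_key_values, use_cache, training):
    # Hypothetical helper mirroring the inline pattern in each patched forward.
    # Before this commit the condition was just
    # `use_cache and not isinstance(past_key_values, Cache)`, so a training run
    # with use_cache enabled would also convert (and warn about) legacy caches.
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache) and not training:
        return_legacy_cache = True  # caller converts back to tuples on return
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values, return_legacy_cache
```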
@@ -769,7 +769,9 @@ class CohereModel(CoherePreTrainedModel):
         past_seen_tokens = 0
         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

@@ -1005,7 +1005,9 @@ class DbrxModel(DbrxPreTrainedModel):
         inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(

@@ -474,7 +474,9 @@ class GemmaModel(LlamaModel):
             inputs_embeds = self.embed_tokens(input_ids)
         return_legacy_cache = False  # noqa: F841
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True  # noqa: F841
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

@@ -770,7 +770,9 @@ class GemmaModel(GemmaPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)
         return_legacy_cache = False  # noqa: F841
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True  # noqa: F841
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
@@ -795,7 +797,9 @@ class GemmaModel(GemmaPreTrainedModel):
         # See https://github.com/huggingface/transformers/pull/29402
         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
         hidden_states = hidden_states * normalizer
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(

@@ -978,7 +978,9 @@ class JetMoeModel(JetMoePreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)
         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

@@ -894,7 +894,9 @@ class LlamaModel(LlamaPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)
         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(

@@ -758,7 +758,7 @@ class MistralModel(MistralPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)
         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             return_legacy_cache = True
             logger.warning_once(

@@ -960,7 +960,7 @@ class MixtralModel(MixtralPreTrainedModel):
             use_cache = False
         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(

@@ -811,7 +811,9 @@ class OlmoModel(OlmoPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)
         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
             return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(

@@ -626,7 +626,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
             use_cache = False
         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(

@@ -909,7 +909,7 @@ class PhiModel(PhiPreTrainedModel):
             use_cache = False
         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(

@@ -950,7 +950,7 @@ class Phi3Model(Phi3PreTrainedModel):
             use_cache = False
         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(

@@ -808,7 +808,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
             use_cache = False
         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(

@@ -970,7 +970,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
             use_cache = False
         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(

@@ -902,7 +902,7 @@ class StableLmModel(StableLmPreTrainedModel):
             use_cache = False
         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(

@@ -784,7 +784,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
             use_cache = False
         use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):
+        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
             use_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             logger.warning_once(
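The observable effect, as a hedged usage sketch (the checkpoint name is only illustrative; any model touched by this commit behaves the same):

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative tiny checkpoint; substitute any of the patched models.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
input_ids = torch.randint(0, model.config.vocab_size, (1, 8))

model.train()
# Training-mode forward: `not self.training` is False, so no DynamicCache is
# built from the legacy/None past_key_values and warning_once stays silent.
_ = model(input_ids)

model.eval()
# Eval-mode forward: the backward-compatibility conversion still applies, so a
# legacy tuple cache passed here is still converted (and, in some models,
# still triggers the one-time deprecation warning).
_ = model(input_ids, use_cache=True)
```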