Unverified Commit 7e662e6a authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix model templates and use less than 119 chars (#9684)

* Fix model templates and use less than 119 chars

* Missing new line
parent 2ebbbf55
......@@ -997,7 +997,8 @@ class BartDecoder(BartPretrainedModel):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -544,7 +544,8 @@ class BertEncoder(nn.Module):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -959,7 +959,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -959,7 +959,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -541,7 +541,8 @@ class ElectraEncoder(nn.Module):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -721,7 +721,8 @@ class GPT2Model(GPT2PreTrainedModel):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -471,7 +471,8 @@ class LayoutLMEncoder(nn.Module):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -1924,7 +1924,8 @@ class LEDDecoder(LEDPreTrainedModel):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -962,7 +962,8 @@ class MarianDecoder(MarianPreTrainedModel):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -1006,7 +1006,8 @@ class MBartDecoder(MBartPreTrainedModel):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -970,7 +970,8 @@ class PegasusDecoder(PegasusPreTrainedModel):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -484,7 +484,8 @@ class RobertaEncoder(nn.Module):
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
......
......@@ -526,8 +526,16 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment