Unverified Commit 7e662e6a authored by Sylvain Gugger, committed by GitHub

Fix model templates and use less than 119 chars (#9684)

* Fix model templates and use less than 119 chars

* Missing new line
parent 2ebbbf55
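
The length fix below relies on Python's implicit concatenation of adjacent string literals: splitting the long warning message across two literals yields the exact same string at runtime, just under the 119-character line limit. A minimal sketch (not part of the commit) confirming the equivalence:

# Adjacent string literals inside parentheses are concatenated at compile time,
# so the split message is byte-for-byte identical to the original one-liner.
split = (
    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
    "`use_cache=False`..."
)
single = "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting `use_cache=False`..."
assert split == single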
@@ -997,7 +997,8 @@ class BartDecoder(BartPretrainedModel):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -544,7 +544,8 @@ class BertEncoder(nn.Module):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -959,7 +959,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -959,7 +959,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -541,7 +541,8 @@ class ElectraEncoder(nn.Module):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -721,7 +721,8 @@ class GPT2Model(GPT2PreTrainedModel):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -471,7 +471,8 @@ class LayoutLMEncoder(nn.Module):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -1924,7 +1924,8 @@ class LEDDecoder(LEDPreTrainedModel):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -962,7 +962,8 @@ class MarianDecoder(MarianPreTrainedModel):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -1006,7 +1006,8 @@ class MBartDecoder(MBartPreTrainedModel):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -970,7 +970,8 @@ class PegasusDecoder(PegasusPreTrainedModel):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -484,7 +484,8 @@ class RobertaEncoder(nn.Module):
                 if use_cache:
                     logger.warn(
-                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
                     )
                     use_cache = False

@@ -526,8 +526,16 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
             if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
+                    )
+                    use_cache = False
+
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         return module(*inputs, past_key_value, output_attentions)
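
The template hunk above is the "fix model templates" half of the commit: newly generated encoders gain the same use_cache guard the hand-written models already carry. As a rough illustration of why the guard exists, here is a minimal, self-contained sketch (all names below are hypothetical; this is not the transformers code): torch.utils.checkpoint re-runs the wrapped forward during the backward pass, so building a past-state cache under checkpointing is wasted work at best, and the guard turns caching off with a warning instead.

import logging

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

logger = logging.getLogger(__name__)


class TinyEncoder(nn.Module):
    """Toy stand-in for a generated encoder; sizes and layers are arbitrary."""

    def __init__(self, hidden=32, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.gradient_checkpointing = True

    def forward(self, hidden_states, use_cache=True):
        cache = []
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting "
                        "`use_cache=False`..."
                    )
                    use_cache = False

                # Same idiom as the diff: wrap the layer so the checkpointed
                # callable takes only tensor inputs.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = checkpoint(create_custom_forward(layer), hidden_states)
            else:
                hidden_states = layer(hidden_states)
                if use_cache:
                    cache.append(hidden_states)
        return hidden_states, cache if use_cache else None


x = torch.randn(4, 32, requires_grad=True)
model = TinyEncoder().train()
out, cache = model(x)      # warns on the first layer, then runs with use_cache=False
out.sum().backward()       # layers are re-run here instead of storing activations
assert cache is None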