Unverified Commit c7f3abc2 authored by Stas Bekman, committed by GitHub

introduce `logger.warning_once` and use it for grad checkpointing code (#21804)

* logger.warning_once

* style
parent f95f60c8
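The hunks below only show the call sites; the implementation of `warning_once` itself is not visible in this view. As a minimal sketch, assuming an `lru_cache`-based deduplication attached to `logging.Logger` (the method name and attachment mirror the call sites in the diff; the body is an illustration, not necessarily the committed code):

```python
import functools
import logging


@functools.lru_cache(None)
def warning_once(self, *args, **kwargs):
    """Emit a warning only once per unique (logger, message) combination.

    lru_cache memoizes on the call arguments, so a second call with the
    same message is a cache hit and self.warning() never runs again.
    """
    self.warning(*args, **kwargs)


# Attach as a method so call sites can write `logger.warning_once(...)`.
logging.Logger.warning_once = warning_once
```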
@@ -444,7 +444,7 @@ class GitEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -853,7 +853,7 @@ class GPT2Model(GPT2PreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -589,7 +589,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -653,7 +653,7 @@ class GPTJModel(GPTJPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -812,7 +812,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -479,7 +479,7 @@ class LayoutLMEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -2136,7 +2136,7 @@ class LEDDecoder(LEDPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -1055,7 +1055,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting"
                     " `use_cache=False`..."
                 )
...
@@ -1020,7 +1020,7 @@ class MarianDecoder(MarianPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -641,7 +641,7 @@ class MarkupLMEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -1069,7 +1069,7 @@ class MBartDecoder(MBartPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -544,7 +544,7 @@ class MegatronBertEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -1008,7 +1008,7 @@ class MT5Stack(MT5PreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -1212,7 +1212,7 @@ class MvpDecoder(MvpPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -571,7 +571,7 @@ class NezhaEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -671,7 +671,7 @@ class OPTDecoder(OPTPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -1070,7 +1070,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -1311,7 +1311,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -1048,7 +1048,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
@@ -1572,7 +1572,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
...
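For context on why the change matters: these warnings live inside `forward`, so under gradient checkpointing a plain `logger.warning` fires on every training step. A short usage sketch (the loop is illustrative; `get_logger` is the existing `transformers.utils.logging` helper):

```python
from transformers.utils import logging

logger = logging.get_logger(__name__)

for step in range(3):
    # Before this commit, logger.warning(...) would print on every iteration;
    # warning_once deduplicates, so the message appears only on the first call.
    logger.warning_once(
        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
    )
```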