Unverified Commit c7f3abc2 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

introduce `logger.warning_once` and use it for grad checkpointing code (#21804)

* logger.warning_once

* style
parent f95f60c8
......@@ -444,7 +444,7 @@ class GitEncoder(nn.Module):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -853,7 +853,7 @@ class GPT2Model(GPT2PreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -589,7 +589,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -653,7 +653,7 @@ class GPTJModel(GPTJPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -812,7 +812,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -479,7 +479,7 @@ class LayoutLMEncoder(nn.Module):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -2136,7 +2136,7 @@ class LEDDecoder(LEDPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -1055,7 +1055,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting"
" `use_cache=False`..."
)
......
......@@ -1020,7 +1020,7 @@ class MarianDecoder(MarianPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -641,7 +641,7 @@ class MarkupLMEncoder(nn.Module):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -1069,7 +1069,7 @@ class MBartDecoder(MBartPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -544,7 +544,7 @@ class MegatronBertEncoder(nn.Module):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -1008,7 +1008,7 @@ class MT5Stack(MT5PreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -1212,7 +1212,7 @@ class MvpDecoder(MvpPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -571,7 +571,7 @@ class NezhaEncoder(nn.Module):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -671,7 +671,7 @@ class OPTDecoder(OPTPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -1070,7 +1070,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -1311,7 +1311,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -1048,7 +1048,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -1572,7 +1572,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment