Unverified Commit c7f3abc2, authored by Stas Bekman, committed by GitHub

introduce `logger.warning_once` and use it for grad checkpointing code (#21804)

* logger.warning_once

* style

parent f95f60c8
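
The diff below swaps `logger.warning` for `logger.warning_once` at every gradient-checkpointing call site, so the incompatibility notice is logged once per process instead of on every forward pass. As a minimal usage sketch (assuming a transformers build that already contains this commit; the loop and message are illustrative only):

```python
from transformers.utils import logging

logger = logging.get_logger(__name__)

# With plain `logger.warning` the message below would be printed on every call;
# with `warning_once` only the first call emits it, later calls are no-ops.
for _ in range(3):
    logger.warning_once(
        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
    )
```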
@@ -575,7 +575,7 @@ class QDQBertEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -578,7 +578,7 @@ class RealmEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -536,7 +536,7 @@ class RemBertEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -502,7 +502,7 @@ class RobertaEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -504,7 +504,7 @@ class RobertaPreLayerNormEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -636,7 +636,7 @@ class RoCBertEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -573,7 +573,7 @@ class RoFormerEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -1692,7 +1692,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -451,7 +451,7 @@ class SplinterEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -1057,7 +1057,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -1037,7 +1037,7 @@ class T5Stack(T5PreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -1471,7 +1471,7 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -543,7 +543,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -1595,7 +1595,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -503,7 +503,7 @@ class XLMRobertaEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -492,7 +492,7 @@ class XLMRobertaXLEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False

@@ -566,7 +566,7 @@ class XmodEncoder(nn.Module):
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning(
+                logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
@@ -14,6 +14,8 @@
 # limitations under the License.
 """ Logging utilities."""
 
+
+import functools
 import logging
 import os
 import sys
@@ -281,6 +283,21 @@ def warning_advice(self, *args, **kwargs):
 logging.Logger.warning_advice = warning_advice
 
 
+@functools.lru_cache(None)
+def warning_once(self, *args, **kwargs):
+    """
+    This method is identical to `logger.warning()`, but will emit the warning with the same message only once
+
+    Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
+    The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
+    another type of cache that includes the caller frame information in the hashing function.
+    """
+    self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_once = warning_once
+
+
 class EmptyTqdm:
     """Dummy tqdm which doesn't do anything."""
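
For reference, the same pattern can be reproduced outside transformers with only the standard library. This standalone sketch (the "demo" logger name and messages are hypothetical) shows why repeated calls with an identical message are suppressed, and why two call sites that happen to share the same message also share one cache entry, as the docstring above notes:

```python
import functools
import logging


@functools.lru_cache(None)
def warning_once(self, *args, **kwargs):
    # lru_cache keys on (self, *args, **kwargs): the first call with a given
    # (logger, message) pair logs normally; every repeat is served from the
    # cache and therefore logs nothing.
    self.warning(*args, **kwargs)


# Attach the cached function as a method on all Logger instances.
logging.Logger.warning_once = warning_once

logger = logging.getLogger("demo")
logger.warning_once("use_cache disabled")  # emitted
logger.warning_once("use_cache disabled")  # cached -> silent
logger.warning_once("another message")     # new arguments -> emitted
```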