Unverified Commit c9035e45 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

fix: The 'warn' method is deprecated (#11105)

* The 'warn' method is deprecated

* fix test
parent 247bed38
...@@ -484,7 +484,7 @@ class RobertaEncoder(nn.Module): ...@@ -484,7 +484,7 @@ class RobertaEncoder(nn.Module):
if getattr(self.config, "gradient_checkpointing", False) and self.training: if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
logger.warn( logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..." "`use_cache=False`..."
) )
......
...@@ -1015,7 +1015,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): ...@@ -1015,7 +1015,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
if getattr(self.config, "gradient_checkpointing", False) and self.training: if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
logger.warn( logger.warning(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
) )
use_cache = False use_cache = False
......
...@@ -111,7 +111,7 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned): ...@@ -111,7 +111,7 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
if not is_used: if not is_used:
unused_weights.append(name) unused_weights.append(name)
logger.warn(f"Unused weights: {unused_weights}") logger.warning(f"Unused weights: {unused_weights}")
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm): def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
......
...@@ -1140,7 +1140,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss): ...@@ -1140,7 +1140,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
) )
if inputs["lengths"] is not None: if inputs["lengths"] is not None:
logger.warn( logger.warning(
"The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the " "The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
"attention mask instead.", "attention mask instead.",
) )
......
...@@ -1232,7 +1232,7 @@ class XLMForMultipleChoice(XLMPreTrainedModel): ...@@ -1232,7 +1232,7 @@ class XLMForMultipleChoice(XLMPreTrainedModel):
) )
if lengths is not None: if lengths is not None:
logger.warn( logger.warning(
"The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the " "The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
"attention mask instead." "attention mask instead."
) )
......
...@@ -142,7 +142,7 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -142,7 +142,7 @@ class ZeroShotClassificationPipeline(Pipeline):
""" """
if "multi_class" in kwargs and kwargs["multi_class"] is not None: if "multi_class" in kwargs and kwargs["multi_class"] is not None:
multi_label = kwargs.pop("multi_class") multi_label = kwargs.pop("multi_class")
logger.warn( logger.warning(
"The `multi_class` argument has been deprecated and renamed to `multi_label`. " "The `multi_class` argument has been deprecated and renamed to `multi_label`. "
"`multi_class` will be removed in a future version of Transformers." "`multi_class` will be removed in a future version of Transformers."
) )
......
...@@ -289,7 +289,7 @@ class CallbackHandler(TrainerCallback): ...@@ -289,7 +289,7 @@ class CallbackHandler(TrainerCallback):
self.eval_dataloader = None self.eval_dataloader = None
if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks): if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
logger.warn( logger.warning(
"The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n" "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
+ "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of" + "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of"
+ "callbacks is\n:" + "callbacks is\n:"
...@@ -300,7 +300,7 @@ class CallbackHandler(TrainerCallback): ...@@ -300,7 +300,7 @@ class CallbackHandler(TrainerCallback):
cb = callback() if isinstance(callback, type) else callback cb = callback() if isinstance(callback, type) else callback
cb_class = callback if isinstance(callback, type) else callback.__class__ cb_class = callback if isinstance(callback, type) else callback.__class__
if cb_class in [c.__class__ for c in self.callbacks]: if cb_class in [c.__class__ for c in self.callbacks]:
logger.warn( logger.warning(
f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current" f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current"
+ "list of callbacks is\n:" + "list of callbacks is\n:"
+ self.callback_list + self.callback_list
......
...@@ -391,7 +391,7 @@ class DistributedTensorGatherer: ...@@ -391,7 +391,7 @@ class DistributedTensorGatherer:
if self._storage is None: if self._storage is None:
return return
if self._offsets[0] != self.process_length: if self._offsets[0] != self.process_length:
logger.warn("Not all data has been set. Are you sure you passed all values?") logger.warning("Not all data has been set. Are you sure you passed all values?")
return nested_truncate(self._storage, self.num_samples) return nested_truncate(self._storage, self.num_samples)
...@@ -589,7 +589,7 @@ def _get_learning_rate(self): ...@@ -589,7 +589,7 @@ def _get_learning_rate(self):
last_lr = self.lr_scheduler.get_last_lr()[0] last_lr = self.lr_scheduler.get_last_lr()[0]
except AssertionError as e: except AssertionError as e:
if "need to call step" in str(e): if "need to call step" in str(e):
logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0") logger.warning("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
last_lr = 0 last_lr = 0
else: else:
raise raise
......
...@@ -531,7 +531,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module): ...@@ -531,7 +531,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
if getattr(self.config, "gradient_checkpointing", False) and self.training: if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
logger.warn( logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..." "`use_cache=False`..."
) )
...@@ -2512,7 +2512,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model ...@@ -2512,7 +2512,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
if getattr(self.config, "gradient_checkpointing", False) and self.training: if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
logger.warn("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...") logger.warning("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...")
use_cache = False use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
......
...@@ -353,7 +353,7 @@ def main(): ...@@ -353,7 +353,7 @@ def main():
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
else: else:
logger.warn( logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: ", "Your model seems to have been trained with labels, but they don't match the dataset: ",
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
"\nIgnoring the model labels as a result.", "\nIgnoring the model labels as a result.",
...@@ -362,7 +362,7 @@ def main(): ...@@ -362,7 +362,7 @@ def main():
label_to_id = {v: i for i, v in enumerate(label_list)} label_to_id = {v: i for i, v in enumerate(label_list)}
if data_args.max_seq_length > tokenizer.model_max_length: if data_args.max_seq_length > tokenizer.model_max_length:
logger.warn( logger.warning(
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
) )
......
...@@ -51,7 +51,7 @@ class HfArgumentParserTest(unittest.TestCase): ...@@ -51,7 +51,7 @@ class HfArgumentParserTest(unittest.TestCase):
# should be able to log warnings (if default settings weren't overridden by `pytest --log-level-all`) # should be able to log warnings (if default settings weren't overridden by `pytest --log-level-all`)
if level_origin <= logging.WARNING: if level_origin <= logging.WARNING:
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
logger.warn(msg) logger.warning(msg)
self.assertEqual(cl.out, msg + "\n") self.assertEqual(cl.out, msg + "\n")
# this is setting the level for all of `transformers.*` loggers # this is setting the level for all of `transformers.*` loggers
...@@ -59,7 +59,7 @@ class HfArgumentParserTest(unittest.TestCase): ...@@ -59,7 +59,7 @@ class HfArgumentParserTest(unittest.TestCase):
# should not be able to log warnings # should not be able to log warnings
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
logger.warn(msg) logger.warning(msg)
self.assertEqual(cl.out, "") self.assertEqual(cl.out, "")
# should be able to log warnings again # should be able to log warnings again
......
...@@ -234,7 +234,7 @@ class TrainerCallbackTest(unittest.TestCase): ...@@ -234,7 +234,7 @@ class TrainerCallbackTest(unittest.TestCase):
self.assertEqual(events, self.get_expected_events(trainer)) self.assertEqual(events, self.get_expected_events(trainer))
# warning should be emitted for duplicated callbacks # warning should be emitted for duplicated callbacks
with unittest.mock.patch("transformers.trainer_callback.logger.warn") as warn_mock: with unittest.mock.patch("transformers.trainer_callback.logger.warning") as warn_mock:
trainer = self.get_trainer( trainer = self.get_trainer(
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback], callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment