Unverified Commit 6a5472a8 authored by Yih-Dar, committed by GitHub
Browse files

Force use_cache to be False in PyTorch (#15385)



* use_cache = False for PT models if labels is passed

* Fix for BigBirdPegasusForConditionalGeneration

* add warning if users specify use_cache=True

* Use logger.warning instead of warnings.warn

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 0acd84f7
...@@ -1318,6 +1318,9 @@ class BartForConditionalGeneration(BartPretrainedModel): ...@@ -1318,6 +1318,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None and decoder_inputs_embeds is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
......
...@@ -2513,6 +2513,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): ...@@ -2513,6 +2513,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None and decoder_inputs_embeds is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
......
...@@ -1287,6 +1287,9 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): ...@@ -1287,6 +1287,9 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
......
...@@ -1258,6 +1258,9 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): ...@@ -1258,6 +1258,9 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
......
...@@ -2366,6 +2366,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): ...@@ -2366,6 +2366,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
......
...@@ -1291,6 +1291,9 @@ class MarianMTModel(MarianPreTrainedModel): ...@@ -1291,6 +1291,9 @@ class MarianMTModel(MarianPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
......
...@@ -1314,6 +1314,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): ...@@ -1314,6 +1314,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
......
...@@ -1381,6 +1381,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): ...@@ -1381,6 +1381,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
......
...@@ -2832,6 +2832,9 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte ...@@ -2832,6 +2832,9 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment