Unverified commit 5fabd1e8, authored by Raushan Turganbay, committed by GitHub

Generation: fix handling of special tokens (#31254)

* fix special tokens in generation

* fix test

* add warning

* fix the check

* warn once

* fix
parent 7729b774
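
At a glance: special tokens (`bos`, `eos`, `pad`, `decoder_start`) are now resolved with a uniform "per-call value, else model default" rule, instead of reading only the per-call `generation_config` plus an ad-hoc BC path for the decoder start token. One consequence, visible in the test changes below, is that passing `None` for a token no longer disables it; it defers to `model.generation_config`. A minimal sketch of the rule in isolation (the function name and values are illustrative, not from the diff):

```python
from typing import Optional

def resolve_token(per_call: Optional[int], model_default: Optional[int]) -> Optional[int]:
    # Per-call value wins; None defers to the model's generation config.
    return per_call if per_call is not None else model_default

assert resolve_token(2, 50256) == 2          # explicit per-call value wins
assert resolve_token(None, 50256) == 50256   # None falls back, i.e. does NOT disable the token
```

The diff below is against `GenerationMixin` in `src/transformers/generation/utils.py`.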
@@ -1436,23 +1436,6 @@ class GenerationMixin:
             self._cache.reset()
         return self._cache
 
-    def _get_decoder_start_token_id(
-        self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
-    ) -> int:
-        decoder_start_token_id = (
-            decoder_start_token_id
-            if decoder_start_token_id is not None
-            else self.generation_config.decoder_start_token_id
-        )
-        bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
-        if decoder_start_token_id is not None:
-            return decoder_start_token_id
-        elif bos_token_id is not None:
-            return bos_token_id
-        else:
-            return
-
     def _supports_default_dynamic_cache(self) -> bool:
         """
         Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
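
The helper removed above encoded a two-step fallback: explicit `decoder_start_token_id`, else `bos_token_id`, else `None`. That chain survives as a one-liner inside `_prepare_special_tokens` (next hunk). A pure-Python sketch of the equivalence, with illustrative values:

```python
def old_style(decoder_start_token_id, bos_token_id):
    # behavior of the deleted _get_decoder_start_token_id, minus the config lookups
    if decoder_start_token_id is not None:
        return decoder_start_token_id
    elif bos_token_id is not None:
        return bos_token_id
    return None

def new_style(decoder_start_token_id, bos_token_id):
    # the one-liner the next hunk introduces for encoder-decoder models
    return decoder_start_token_id if decoder_start_token_id is not None else bos_token_id

for args in [(2, 0), (None, 0), (None, None)]:
    assert old_style(*args) == new_style(*args)
```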
@@ -1478,25 +1461,32 @@ class GenerationMixin:
         function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
         """
 
-        # Convert special tokens to tensors (if they exist)
-        def _tensor_or_none(token, device=None):
+        # Convert special tokens to tensors (if they exist either in kwargs or in self.config)
+        def _tensor_or_none(token_kwargs, token_self, device=None):
             if device is None:
                 device = self.device
 
+            token = token_kwargs if token_kwargs is not None else token_self
             if token is None or isinstance(token, torch.Tensor):
                 return token
             return torch.tensor(token, device=device, dtype=torch.long)
 
-        # for BC we also try to get `decoder_start_token_id` from model's generation config (#30892)
-        if self.config.is_encoder_decoder:
-            generation_config.decoder_start_token_id = self._get_decoder_start_token_id(
-                generation_config.decoder_start_token_id, generation_config.bos_token_id
-            )
-
-        bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
-        eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
-        pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
-        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
+        bos_token_id = _tensor_or_none(
+            generation_config.bos_token_id, self.generation_config.bos_token_id, device=device
+        )
+        eos_token_id = _tensor_or_none(
+            generation_config.eos_token_id, self.generation_config.eos_token_id, device=device
+        )
+        pad_token_id = _tensor_or_none(
+            generation_config.pad_token_id, self.generation_config.pad_token_id, device=device
+        )
+        decoder_start_token_id = _tensor_or_none(
+            generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device
+        )
+
+        # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
+        if self.config.is_encoder_decoder:
+            decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
 
         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
         if eos_token_id is not None and eos_token_id.ndim == 0:
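
The reworked `_tensor_or_none` now receives both the per-call value and the model-level default, picks the winner, and converts it to a long tensor. A self-contained approximation (the real method also defaults `device` to `self.device`; names here are illustrative):

```python
import torch
from typing import Optional

def tensor_or_none(token_kwargs, token_self, device="cpu") -> Optional[torch.Tensor]:
    # Approximation of the new _tensor_or_none: the per-call value wins,
    # the model default fills the gap, and the result is a long tensor.
    token = token_kwargs if token_kwargs is not None else token_self
    if token is None or isinstance(token, torch.Tensor):
        return token
    return torch.tensor(token, device=device, dtype=torch.long)

print(tensor_or_none([50256, 50257], 50256))  # tensor([50256, 50257]) -- already 1D
print(tensor_or_none(None, 50256))            # tensor(50256) -- 0-d, unsqueezed by the hunk above
print(tensor_or_none(None, None))             # None
```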
@@ -1512,6 +1502,15 @@ class GenerationMixin:
             pad_token_id = eos_token_id[0]
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
 
+        # we can't infer attn mask if pad token is set to be eos token in model's generation config
+        if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
+            if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
+                logger.warning_once(
+                    "The attention mask is not set and cannot be inferred from input because pad token is same as "
+                    "eos token. As a consequence, you may observe unexpected behavior. Please pass your input's "
+                    "`attention_mask` to obtain reliable results."
+                )
+
         # Sanity checks/warnings
         if self.config.is_encoder_decoder and decoder_start_token_id is None:
             raise ValueError(
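
The new `warning_once` covers a case the old code handled silently: when the (possibly defaulted) pad token coincides with an eos token and the caller supplies no attention mask, padding positions cannot be distinguished from genuine eos tokens, so the inferred mask is unreliable. A hedged repro sketch (the checkpoint and prompts are just convenient examples):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
tokenizer.pad_token = tokenizer.eos_token  # common setup for GPT-2 style models
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")

inputs = tokenizer(["short", "a much longer prompt"], padding=True, return_tensors="pt")
# Passing only input_ids (no attention_mask) with pad == eos should now emit the
# "attention mask is not set and cannot be inferred" warning exactly once.
model.generate(inputs.input_ids, max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)
```

The test updates below, in `GenerationIntegrationTestsMixin`, adapt to the new fallback semantics.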
@@ -161,6 +161,7 @@ class GenerationIntegrationTestsMixin:
         tokenizer.pad_token = tokenizer.eos_token
         model = model_cls.from_pretrained("distilbert/distilgpt2")
+        model.generation_config.eos_token_id = None
         input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
         if is_pt:
             model = model.to(torch_device)
@@ -170,7 +171,6 @@ class GenerationIntegrationTestsMixin:
             input_ids=input_ids,
             max_new_tokens=5,
             pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=None,
             return_dict_in_generate=True,
             output_scores=True,
         )
@@ -197,6 +197,7 @@ class GenerationIntegrationTestsMixin:
         tokenizer.pad_token = tokenizer.eos_token
         model = model_cls.from_pretrained("distilbert/distilgpt2")
+        model.generation_config.eos_token_id = None
         input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
         if is_pt:
             model = model.to(torch_device)
@@ -206,7 +207,6 @@ class GenerationIntegrationTestsMixin:
             input_ids=input_ids,
             max_new_tokens=5,
             pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=None,
             return_dict_in_generate=True,
             output_scores=True,
         )
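
Why the tests changed: with the fallback in place, `eos_token_id=None` passed to `generate()` defers to `model.generation_config.eos_token_id` rather than disabling the eos token, so the tests now unset it on the model's config instead. A sketch of the distinction (same illustrative checkpoint as above):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
input_ids = tokenizer("Hello", return_tensors="pt").input_ids

# No-op after this commit: None falls back to model.generation_config.eos_token_id.
model.generate(input_ids, max_new_tokens=5, eos_token_id=None)

# Effective: unset eos on the model's own generation config.
model.generation_config.eos_token_id = None
model.generate(input_ids, max_new_tokens=5)
```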