Unverified Commit beb24f2a authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: FLAX infers pad token in its absence and has functional example (#21009)

parent 480799f7
......@@ -305,10 +305,10 @@ class FlaxGenerationMixin:
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
>>> input_context = "The dog"
>>> # encode input context
>>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
>>> inputs = tokenizer(input_context, return_tensors="np")
>>> # generate candidates using sampling
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
>>> outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
```"""
# Validate the `.generate()` call
self._validate_model_class()
......@@ -323,6 +323,17 @@ class FlaxGenerationMixin:
)
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
if pad_token_id is None and eos_token_id is not None:
if model_kwargs.get("attention_mask") is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
if decoder_start_token_id is None and self.config.is_encoder_decoder:
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
......@@ -525,8 +536,8 @@ class FlaxGenerationMixin:
batch_size, cur_len = input_ids.shape
eos_token_id = jnp.array(eos_token_id)
pad_token_id = jnp.array(pad_token_id)
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
cur_len = jnp.array(cur_len)
# per batch-item holding current token in loop.
......@@ -614,8 +625,8 @@ class FlaxGenerationMixin:
batch_size, cur_len = input_ids.shape
eos_token_id = jnp.array(eos_token_id)
pad_token_id = jnp.array(pad_token_id)
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
cur_len = jnp.array(cur_len)
# per batch-item holding current token in loop.
......@@ -748,8 +759,8 @@ class FlaxGenerationMixin:
batch_size, num_beams, cur_len = input_ids.shape
eos_token_id = jnp.array(eos_token_id)
pad_token_id = jnp.array(pad_token_id)
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
cur_len = jnp.array(cur_len)
# per batch,beam-item holding current token in loop.
......
......@@ -702,11 +702,11 @@ class TFGenerationMixin:
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
logger.warning(
f"Setting `pad_token_id` to {generation_config.eos_token_id} (first `eos_token_id`) to generate"
" sequence"
)
generation_config.pad_token_id = generation_config.eos_token_id
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
use_xla = not tf.executing_eagerly()
if use_xla and not self.supports_xla_generation:
......
......@@ -13,6 +13,7 @@ docs/source/en/model_doc/tapex.mdx
docs/source/en/model_doc/donut.mdx
docs/source/en/model_doc/encoder-decoder.mdx
src/transformers/generation/configuration_utils.py
src/transformers/generation/flax_utils.py
src/transformers/generation/tf_utils.py
src/transformers/generation/utils.py
src/transformers/models/albert/configuration_albert.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment