Unverified Commit 6dda14dc authored by Joao Gante, committed by GitHub
Browse files

Generate: fix TF doctests (#20159)

parent e0d7c831
...@@ -2399,8 +2399,8 @@ class TFGenerationMixin: ...@@ -2399,8 +2399,8 @@ class TFGenerationMixin:
... ) ... )
>>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor) >>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) ["Today is a beautiful day, and I'm so happy to be here. I'm so happy to"]
```""" ```"""
# 1. init greedy_search values # 1. init greedy_search values
...@@ -2634,6 +2634,7 @@ class TFGenerationMixin: ...@@ -2634,6 +2634,7 @@ class TFGenerationMixin:
Examples: Examples:
```python ```python
>>> import tensorflow as tf
>>> from transformers import ( >>> from transformers import (
... AutoTokenizer, ... AutoTokenizer,
... TFAutoModelForCausalLM, ... TFAutoModelForCausalLM,
...@@ -2666,9 +2667,11 @@ class TFGenerationMixin: ...@@ -2666,9 +2667,11 @@ class TFGenerationMixin:
... ] ... ]
... ) ... )
>>> tf.random.set_seed(0)
>>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper) >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today is a beautiful day, and it\'s all about the future."\n\nThe announcement comes three']
```""" ```"""
# 1. init greedy_search values # 1. init greedy_search values
...@@ -2927,15 +2930,15 @@ class TFGenerationMixin: ...@@ -2927,15 +2930,15 @@ class TFGenerationMixin:
>>> # lets run beam search using 3 beams >>> # lets run beam search using 3 beams
>>> num_beams = 3 >>> num_beams = 3
>>> # define decoder start token ids >>> # define decoder start token ids
>>> input_ids = tf.ones((num_beams, 1), dtype=tf.int64) >>> input_ids = tf.ones((1, num_beams, 1), dtype=tf.int32)
>>> input_ids = input_ids * model.config.decoder_start_token_id >>> input_ids = input_ids * model.config.decoder_start_token_id
>>> # add encoder_outputs to model keyword arguments >>> # add encoder_outputs to model keyword arguments
>>> model_kwargs = { >>> encoder_outputs = model.get_encoder()(encoder_input_ids, return_dict=True)
... "encoder_outputs": model.get_encoder()( >>> encoder_outputs.last_hidden_state = tf.repeat(
... tf.repeat(encoder_input_ids, num_beams, axis=0), return_dict=True ... tf.expand_dims(encoder_outputs.last_hidden_state, axis=0), num_beams, axis=1
... ) ... )
... } >>> model_kwargs = {"encoder_outputs": encoder_outputs}
>>> # instantiate logits processors >>> # instantiate logits processors
>>> logits_processor = TFLogitsProcessorList( >>> logits_processor = TFLogitsProcessorList(
...@@ -2944,7 +2947,8 @@ class TFGenerationMixin: ...@@ -2944,7 +2947,8 @@ class TFGenerationMixin:
>>> outputs = model.beam_search(input_ids, logits_processor=logits_processor, **model_kwargs) >>> outputs = model.beam_search(input_ids, logits_processor=logits_processor, **model_kwargs)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Wie alt bist du?']
```""" ```"""
def flatten_beam_dim(tensor, batch_axis=0): def flatten_beam_dim(tensor, batch_axis=0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment