"sgl-router/src/vscode:/vscode.git/clone" did not exist on "e5281f84d5bcbdfd6d7790c5ab450429e11291a6"
Unverified Commit 6dda14dc authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: fix TF doctests (#20159)

parent e0d7c831
......@@ -2399,8 +2399,8 @@ class TFGenerationMixin:
... )
>>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["Today is a beautiful day, and I'm so happy to be here. I'm so happy to"]
```"""
# 1. init greedy_search values
......@@ -2634,6 +2634,7 @@ class TFGenerationMixin:
Examples:
```python
>>> import tensorflow as tf
>>> from transformers import (
... AutoTokenizer,
... TFAutoModelForCausalLM,
......@@ -2666,9 +2667,11 @@ class TFGenerationMixin:
... ]
... )
>>> tf.random.set_seed(0)
>>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today is a beautiful day, and it\'s all about the future."\n\nThe announcement comes three']
```"""
# 1. init greedy_search values
......@@ -2927,15 +2930,15 @@ class TFGenerationMixin:
>>> # lets run beam search using 3 beams
>>> num_beams = 3
>>> # define decoder start token ids
>>> input_ids = tf.ones((num_beams, 1), dtype=tf.int64)
>>> input_ids = tf.ones((1, num_beams, 1), dtype=tf.int32)
>>> input_ids = input_ids * model.config.decoder_start_token_id
>>> # add encoder_outputs to model keyword arguments
>>> model_kwargs = {
... "encoder_outputs": model.get_encoder()(
... tf.repeat(encoder_input_ids, num_beams, axis=0), return_dict=True
>>> encoder_outputs = model.get_encoder()(encoder_input_ids, return_dict=True)
>>> encoder_outputs.last_hidden_state = tf.repeat(
... tf.expand_dims(encoder_outputs.last_hidden_state, axis=0), num_beams, axis=1
... )
... }
>>> model_kwargs = {"encoder_outputs": encoder_outputs}
>>> # instantiate logits processors
>>> logits_processor = TFLogitsProcessorList(
......@@ -2944,7 +2947,8 @@ class TFGenerationMixin:
>>> outputs = model.beam_search(input_ids, logits_processor=logits_processor, **model_kwargs)
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Wie alt bist du?']
```"""
def flatten_beam_dim(tensor, batch_axis=0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment