@@ -459,6 +459,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
...
@@ -459,6 +459,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
self,
self,
input_ids=None,
input_ids=None,
max_length=None,
max_length=None,
min_length=None,
do_sample=True,
do_sample=True,
early_stopping=False,
early_stopping=False,
num_beams=None,
num_beams=None,
...
@@ -470,7 +471,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
...
@@ -470,7 +471,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
pad_token_id=None,
pad_token_id=None,
eos_token_ids=None,
eos_token_ids=None,
length_penalty=None,
length_penalty=None,
no_repeat_ngram_size=None,
num_return_sequences=None,
num_return_sequences=None,
attention_mask=None,
):
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
and beam-search.
...
@@ -564,6 +567,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
...
@@ -564,6 +567,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):