Unverified Commit 9a86321b authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

tf generation utils: remove unused kwargs (#6591)

parent 2a7402cb
...@@ -284,7 +284,7 @@ class TFGenerationMixin: ...@@ -284,7 +284,7 @@ class TFGenerationMixin:
pad_token_id = eos_token_id pad_token_id = eos_token_id
# current position and vocab size # current position and vocab size
cur_len = shape_list(input_ids)[1] cur_len = shape_list(input_ids)[1] # unused
vocab_size = self.config.vocab_size vocab_size = self.config.vocab_size
# set effective batch size and effective batch multiplier according to do_sample # set effective batch size and effective batch multiplier according to do_sample
...@@ -366,10 +366,8 @@ class TFGenerationMixin: ...@@ -366,10 +366,8 @@ class TFGenerationMixin:
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size, no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids, bad_words_ids=bad_words_ids,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size, batch_size=effective_batch_size,
num_return_sequences=num_return_sequences, num_return_sequences=num_return_sequences,
length_penalty=length_penalty, length_penalty=length_penalty,
...@@ -392,10 +390,8 @@ class TFGenerationMixin: ...@@ -392,10 +390,8 @@ class TFGenerationMixin:
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size, no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids, bad_words_ids=bad_words_ids,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size, batch_size=effective_batch_size,
vocab_size=vocab_size, vocab_size=vocab_size,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
...@@ -418,10 +414,8 @@ class TFGenerationMixin: ...@@ -418,10 +414,8 @@ class TFGenerationMixin:
repetition_penalty, repetition_penalty,
no_repeat_ngram_size, no_repeat_ngram_size,
bad_words_ids, bad_words_ids,
bos_token_id,
pad_token_id, pad_token_id,
eos_token_id, eos_token_id,
decoder_start_token_id,
batch_size, batch_size,
vocab_size, vocab_size,
encoder_outputs, encoder_outputs,
...@@ -582,9 +576,7 @@ class TFGenerationMixin: ...@@ -582,9 +576,7 @@ class TFGenerationMixin:
repetition_penalty, repetition_penalty,
no_repeat_ngram_size, no_repeat_ngram_size,
bad_words_ids, bad_words_ids,
bos_token_id,
pad_token_id, pad_token_id,
decoder_start_token_id,
eos_token_id, eos_token_id,
batch_size, batch_size,
num_return_sequences, num_return_sequences,
...@@ -616,6 +608,7 @@ class TFGenerationMixin: ...@@ -616,6 +608,7 @@ class TFGenerationMixin:
# cache compute states # cache compute states
past = encoder_outputs past = encoder_outputs
# to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None
# done sentences # done sentences
done = [False for _ in range(batch_size)] done = [False for _ in range(batch_size)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment