"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "de635af3f1ef740aa32f53a91473269c6435e19e"
Unverified Commit a6c82d45 authored by Joao Gante, committed by GitHub

Generate: return `past_key_values` (#25086)

parent 441c3e0d
@@ -104,12 +104,20 @@ class GreedySearchDecoderOnlyOutput(ModelOutput):
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format; confirm with the model's documentation.
Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The outer tuple is of length `config.n_layers`, with each inner tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally when
`config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
""" """
sequences: torch.LongTensor = None sequences: torch.LongTensor = None
scores: Optional[Tuple[torch.FloatTensor]] = None scores: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
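For reference, a minimal sketch of how a caller can now retrieve the cache from `generate` (assuming a standard decoder-only checkpoint such as `gpt2`, which is not part of this diff; other models may use a different cache format, as the docstring notes):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown", return_tensors="pt")
outputs = model.generate(
    **inputs, max_new_tokens=4, return_dict_in_generate=True, use_cache=True
)

# One element per decoder layer; each element is a (key, value) pair of
# tensors shaped (batch_size, num_heads, sequence_length, embed_size_per_head)
print(len(outputs.past_key_values), outputs.past_key_values[0][0].shape)
```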
@dataclass
@@ -140,6 +148,13 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format; confirm with the model's documentation.
Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The outer tuple is of length `config.n_layers`, with each inner tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally when
`config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
""" """
sequences: torch.LongTensor = None sequences: torch.LongTensor = None
...@@ -149,6 +164,7 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput): ...@@ -149,6 +164,7 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@@ -169,15 +185,23 @@ class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is
passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples
(one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length,
hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format; confirm with the model's documentation.
Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The outer tuple is of length `config.n_layers`, with each inner tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally when
`config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
""" """
sequences: torch.LongTensor = None sequences: torch.LongTensor = None
scores: Optional[Tuple[torch.FloatTensor]] = None scores: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@@ -211,6 +235,13 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format; confirm with the model's documentation.
Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The outer tuple is of length `config.n_layers`, with each inner tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally when
`config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
""" """
sequences: torch.LongTensor = None sequences: torch.LongTensor = None
...@@ -220,6 +251,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput): ...@@ -220,6 +251,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@@ -243,12 +275,20 @@ class SampleDecoderOnlyOutput(ModelOutput):
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format; confirm with the model's documentation.
Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The outer tuple is of length `config.n_layers`, with each inner tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally when
`config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
""" """
sequences: torch.LongTensor = None sequences: torch.LongTensor = None
scores: Optional[Tuple[torch.FloatTensor]] = None scores: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@@ -283,6 +323,13 @@ class SampleEncoderDecoderOutput(ModelOutput):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format; confirm with the model's documentation.
Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The outer tuple is of length `config.n_layers`, with each inner tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally when
`config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
""" """
sequences: torch.LongTensor = None sequences: torch.LongTensor = None
...@@ -292,6 +339,7 @@ class SampleEncoderDecoderOutput(ModelOutput): ...@@ -292,6 +339,7 @@ class SampleEncoderDecoderOutput(ModelOutput):
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@@ -319,6 +367,13 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format; confirm with the model's documentation.
Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The outer tuple is of length `config.n_layers`, with each inner tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally when
`config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
""" """
sequences: torch.LongTensor = None sequences: torch.LongTensor = None
...@@ -327,6 +382,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput): ...@@ -327,6 +382,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
beam_indices: Optional[torch.LongTensor] = None beam_indices: Optional[torch.LongTensor] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@@ -366,6 +422,13 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format; confirm with the model's documentation.
Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The outer tuple is of length `config.n_layers`, with each inner tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally when
`config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
""" """
sequences: torch.LongTensor = None sequences: torch.LongTensor = None
...@@ -377,6 +440,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput): ...@@ -377,6 +440,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@@ -404,6 +468,13 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format; confirm with the model's documentation.
Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The outer tuple is of length `config.n_layers`, with each inner tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally when
`config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
""" """
sequences: torch.LongTensor = None sequences: torch.LongTensor = None
...@@ -412,6 +483,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput): ...@@ -412,6 +483,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
beam_indices: Optional[torch.LongTensor] = None beam_indices: Optional[torch.LongTensor] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@@ -450,6 +522,13 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format; confirm with the model's documentation.
Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The outer tuple is of length `config.n_layers`, with each inner tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally when
`config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
""" """
sequences: torch.LongTensor = None sequences: torch.LongTensor = None
...@@ -461,6 +540,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput): ...@@ -461,6 +540,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
@@ -2148,8 +2228,8 @@ class GenerationMixin:
items.append(item.repeat_interleave(1, dim=0))
else:
items.append(item.repeat_interleave(top_k, dim=0))
new_key_values.append(tuple(items))
model_kwargs["past_key_values"] = tuple(new_key_values)
if sequential:
all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
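The tuple conversion above matters because the expanded cache is now surfaced to the user: `generate` returns `past_key_values` as a tuple of tuples, so contrastive search must not leak intermediate lists. A toy illustration of the `repeat_interleave` expansion itself (shapes are made up for the example, not taken from this diff):

```python
import torch

# Each of the top_k candidate tokens needs its own copy of the past states
# for the look-ahead forward pass, hence the batch-dimension expansion.
top_k = 4
key = torch.randn(2, 12, 5, 64)  # (batch, heads, seq_len, head_dim)
expanded = key.repeat_interleave(top_k, dim=0)
print(expanded.shape)  # torch.Size([8, 12, 5, 64])
```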
@@ -2330,6 +2410,17 @@ class GenerationMixin:
streamer.end()
if return_dict_in_generate:
# Contrastive search works by forward looking at the next token, so we need to exclude it from
# `past_key_values` to be consistent with the other decoding methods
if model_kwargs.get("past_key_values") is not None:
past_key_values = []
for layer in model_kwargs["past_key_values"]:
layer_past_key_values = []
for item in layer:
layer_past_key_values.append(item[..., :-1, :])
past_key_values.append(tuple(layer_past_key_values))
model_kwargs["past_key_values"] = tuple(past_key_values)
if self.config.is_encoder_decoder:
return ContrastiveSearchEncoderDecoderOutput(
sequences=input_ids,
@@ -2339,6 +2430,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return ContrastiveSearchDecoderOnlyOutput(
@@ -2346,6 +2438,7 @@ class GenerationMixin:
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
@@ -2598,6 +2691,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return GreedySearchDecoderOnlyOutput(
@@ -2605,6 +2699,7 @@ class GenerationMixin:
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
@@ -2880,6 +2975,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return SampleDecoderOnlyOutput(
@@ -2887,6 +2983,7 @@ class GenerationMixin:
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
@@ -3201,6 +3298,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return BeamSearchDecoderOnlyOutput(
@@ -3210,6 +3308,7 @@ class GenerationMixin:
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return sequence_outputs["sequences"]
@@ -3530,6 +3629,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return BeamSampleDecoderOnlyOutput(
@@ -3539,6 +3639,7 @@ class GenerationMixin:
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return sequence_outputs["sequences"]
@@ -3909,6 +4010,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return BeamSearchDecoderOnlyOutput(
@@ -3918,6 +4020,7 @@ class GenerationMixin:
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return sequence_outputs["sequences"]
@@ -4244,6 +4347,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return BeamSearchDecoderOnlyOutput(
@@ -4253,6 +4357,7 @@ class GenerationMixin:
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return sequence_outputs["sequences"]
@@ -4672,6 +4777,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return GreedySearchDecoderOnlyOutput(
@@ -4679,6 +4785,7 @@ class GenerationMixin:
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
...
@@ -1829,6 +1829,85 @@ class GenerationTesterMixin:
outputs_from_embeds_wo_ids[:, 1:].tolist(),
)
def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
# won't fix: old models with unique inputs/caches/others
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
return
# may fix in the future: needs modeling or test input preparation fixes for compatibility
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
return
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# If it doesn't support cache, pass the test
if not hasattr(config, "use_cache"):
return
# Let's make it always:
# 1. use cache (for obvious reasons)
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value); otherwise
# the test would be flaky (e.g. EOS is generated on iteration 1 on both generations, but the
# continuation would force it to generate beyond an EOS token)
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
config.use_cache = True
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
model = model_class(config).to(torch_device)
model.eval()
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
model.generation_config.forced_eos_token_id = None
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
if "past_key_values" not in outputs:
return
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
# Let's generate again, but this time in two steps, passing the past key values in between (3 + 1 = 4
# tokens). Note that the inputs may need to be tweaked across `generate` calls (like the attention mask).
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
# Continue from the tokens generated above, preparing the inputs accordingly
inputs["past_key_values"] = outputs_cached.past_key_values
new_attention_len = outputs_cached.sequences.shape[-1]
if config.is_encoder_decoder:
inputs["decoder_input_ids"] = outputs_cached.sequences
if "decoder_attention_mask" in inputs:
inputs["decoder_attention_mask"] = torch.nn.functional.pad(
inputs["decoder_attention_mask"],
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
mode="constant",
value=1,
)
else:
inputs["input_ids"] = outputs_cached.sequences
if "attention_mask" in inputs:
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"],
(0, new_attention_len - inputs["attention_mask"].shape[1]),
mode="constant",
value=1,
)
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
# The two sets of generated text and past kv should be equal to each other
self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
outputs.past_key_values[layer_idx][kv_idx],
outputs_cached.past_key_values[layer_idx][kv_idx],
)
)
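Outside the test harness, the same pattern looks roughly like the sketch below (assuming the `gpt2` checkpoint, which is not part of this diff; the attention mask handling mirrors the decoder-only branch of the test):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("An apple a day", return_tensors="pt")

# First call: generate 3 tokens greedily and keep the returned cache
step_1 = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)

# Second call: resume from the cache, feeding the full sequence so far and
# an attention mask padded to the new length
step_2 = model.generate(
    input_ids=step_1.sequences,
    attention_mask=torch.ones_like(step_1.sequences),
    past_key_values=step_1.past_key_values,
    do_sample=False,
    max_new_tokens=1,
    return_dict_in_generate=True,
)
# step_2.sequences should match a single 4-token generation from the same prompt
```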
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences
@@ -1894,6 +1973,24 @@ class GenerationTesterMixin:
use_cache=use_cache,
)
# Past Key Value States -- three notes here:
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 3. TODO (joao): A few models have different formats, skipping those until the cache refactor is complete
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer")
has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
)
if use_cache and has_standard_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
num_sequences_in_output,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
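To make the `-1` concrete with a worked example (illustrative numbers only): the last returned token is sampled from a forward pass whose inputs end one position earlier, so that token never enters the cache.

```python
# Generating 4 new tokens from a 4-token prompt yields sequences of length 8,
# but the cache only stores the 7 positions that were fed through the model.
sequences_length = 8
past_sequence_length = sequences_length - 1  # 7 cached positions per layer
```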
def _check_scores(self, batch_size, scores, length, config):
expected_shape = (batch_size, config.vocab_size)
self.assertIsInstance(scores, tuple)
@@ -1959,6 +2056,30 @@ class GenerationTesterMixin:
[encoder_expected_shape] * len(hidden_states),
)
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
self.assertIsInstance(past_key_values, tuple)
self.assertListEqual(
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
[True] * len(past_key_values),
)
# (batch, head, seq_length, head_features)
expected_shape = (
batch_size * num_beam_groups,
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
seq_length,
config.hidden_size // config.num_attention_heads,
)
# check shape key, value
self.assertListEqual(
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
[expected_shape] * len(past_key_values),
)
self.assertListEqual(
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
[expected_shape] * len(past_key_values),
)
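For instance, with hypothetical GPT-2-like numbers (12 attention heads, hidden size 768, 7 cached positions), the shape check above expects:

```python
batch_size, num_beam_groups = 2, 1
num_heads, hidden_size, seq_length = 12, 768, 7
expected_shape = (
    batch_size * num_beam_groups,  # beam groups multiply the batch dimension
    num_heads,                     # num_key_value_heads when the config defines it
    seq_length,
    hidden_size // num_heads,      # embed_size_per_head
)
print(expected_shape)  # (2, 12, 7, 64)
```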
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
# set to same device. we don't care what device.
...