chenpangpang / transformers / Commits / 568e5783
"sgl-router/src/vscode:/vscode.git/clone" did not exist on "d3be97104b09bfcabc7de507bfe8a79455ebce30"
Unverified commit 568e5783, authored Oct 27, 2022 by Joao Gante, committed by GitHub on Oct 27, 2022.
Generate: contrastive search uses existing abstractions and conventions (#19896)
Parent: 803475fb
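For context: contrastive search is triggered through the regular `generate()` API whenever `penalty_alpha > 0` and `top_k > 1`. A minimal usage sketch, not part of the commit; the gpt2 checkpoint and the prompt are illustrative:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("DeepMind Company is", return_tensors="pt")
    # penalty_alpha > 0 together with top_k > 1 routes generate() to contrastive search
    output_ids = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))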
Showing 1 changed file with 60 additions and 50 deletions.

src/transformers/generation_utils.py (+60 −50)
@@ -54,7 +54,7 @@ from .generation_stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
-from .modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
+from .modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
 from .models.auto import (
     MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
     MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -1882,28 +1882,34 @@ class GenerationMixin:
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             # if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values;
-            # (2) last_hidden_states; (3) logit_for_next_step
+            # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
             if model_kwargs.get("past") is None:

                 # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
                 # the `encoder_outputs`
-                output = self(**model_inputs, output_hidden_states=True, output_attentions=True)
+                outputs = self(
+                    **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
+                )

                 # past_key_values is required for fast decoding
-                if "past_key_values" not in output:
+                if "past_key_values" not in outputs:
                     raise ValueError(
                         f"{self.__class__.__name__} cannot return `past_key_values` and can therefore **not** be used "
                         "for contrastive search."
                     )
-                past_key_values = output.past_key_values
+                past_key_values = outputs.past_key_values

                 # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
                 # previous tokens)
                 if self.config.is_encoder_decoder:
-                    last_hidden_states = output.decoder_hidden_states[-1]
+                    last_hidden_states = outputs.decoder_hidden_states[-1]
                 else:
-                    last_hidden_states = output.hidden_states[-1]
+                    last_hidden_states = outputs.hidden_states[-1]

                 # next logit for contrastive search to select top-k candidate tokens
-                logit_for_next_step = output.logits[:, -1, :]
+                logit_for_next_step = outputs.logits[:, -1, :]

+                model_kwargs = self._update_model_kwargs_for_generation(
+                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                )
+
             # contrastive_search main logic start:
             # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
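The hunk above standardizes the first forward pass: `return_dict=True` yields a `ModelOutput`, which supports both `in` membership checks and attribute access. A standalone sketch of what that first pass produces, assuming a decoder-only model such as gpt2 (illustrative, not from the diff):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("DeepMind Company is", return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs, return_dict=True, output_hidden_states=True)

    # ModelOutput behaves like a dict for membership tests, which is what the
    # `"past_key_values" not in outputs` guard above relies on.
    assert "past_key_values" in outputs
    last_hidden_states = outputs.hidden_states[-1]   # [batch, seq_len, hidden]
    logit_for_next_step = outputs.logits[:, -1, :]   # [batch, vocab]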
@@ -1918,6 +1924,18 @@ class GenerationMixin:
             _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=top_k)
             top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)

+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (logit_for_next_step,)
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
             # enlarge the past_key_values
             new_key_values = []
             for layer in past_key_values:
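To make the recall step concrete, a toy run of the `topk`/`gather` pair above (illustrative numbers, not from the diff):

    import torch

    logit_for_next_step = torch.tensor([[2.0, 0.5, 1.0, -1.0]])     # [B=1, vocab=4]
    next_probs = torch.softmax(logit_for_next_step, dim=-1)
    _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=2)     # ids of the 2 best tokens
    top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)  # their model confidence
    print(top_k_ids, top_k_probs)  # tensor([[0, 2]]) and the matching probabilities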
@@ -1937,10 +1955,7 @@ class GenerationMixin:
             # build next attention mask
             if "attention_mask" in model_inputs:
-                attention_mask = model_inputs["attention_mask"]  # [B, S]
-                # decoder-only model need the full attention mask, not only the mask for the last token
-                if self.config.is_encoder_decoder is False:
-                    attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1)
+                attention_mask = model_kwargs["attention_mask"]  # [B, S]
                 attention_mask = attention_mask.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, attention_mask.size(-1))
             else:
                 attention_mask = None
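The expansion above replicates each batch row `top_k` times so all candidate continuations can be scored in a single forward pass; a toy shape check (illustrative values):

    import torch

    bsz, seq_len, top_k = 2, 3, 4
    attention_mask = torch.ones(bsz, seq_len)  # [B, S]
    expanded = attention_mask.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, seq_len)
    print(expanded.shape)  # torch.Size([8, 3]), i.e. [B*K, S]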
@@ -1958,27 +1973,26 @@ class GenerationMixin:
                     encoder_outputs=encoder_outputs,
                 )

             # compute the candidate tokens by the language model and collect their hidden_states
-            output = self(output_hidden_states=True, **next_model_inputs)
-            past_key_values = output.past_key_values
+            outputs = self(
+                **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
+            )
+            past_key_values = outputs.past_key_values

-            logits = output.logits[:, -1, :]
+            logits = outputs.logits[:, -1, :]
             # name is different for encoder-decoder and decoder-only models
             if self.config.is_encoder_decoder:
-                next_hidden = output.decoder_hidden_states[-1]
-                full_hidden_states = output.decoder_hidden_states
+                next_hidden = outputs.decoder_hidden_states[-1]
+                full_hidden_states = outputs.decoder_hidden_states
             else:
-                next_hidden = output.hidden_states[-1]
-                full_hidden_states = output.hidden_states
+                next_hidden = outputs.hidden_states[-1]
+                full_hidden_states = outputs.hidden_states
             context_hidden = (
                 last_hidden_states.unsqueeze(1).expand(-1, top_k, -1, -1).reshape(bsz * top_k, seqlen, embed_dim)
             )

             # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
             # model confidence
             # the scores and index of the selected tokens are returned
-            selected_scores, selected_idx = ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
+            selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)

             # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
             # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
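Shape bookkeeping for the re-ranking inputs above, on toy tensors (illustrative, not from the diff): the context is tiled once per candidate so that every candidate row can be compared against the full prefix.

    import torch

    bsz, top_k, seqlen, embed_dim = 2, 4, 5, 8
    last_hidden_states = torch.randn(bsz, seqlen, embed_dim)
    context_hidden = (
        last_hidden_states.unsqueeze(1).expand(-1, top_k, -1, -1).reshape(bsz * top_k, seqlen, embed_dim)
    )
    next_hidden = torch.randn(bsz * top_k, 1, embed_dim)  # one row per candidate
    print(context_hidden.shape, next_hidden.shape)  # [B*K, S, D] and [B*K, 1, D]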
@@ -1988,11 +2002,11 @@ class GenerationMixin:
             next_hidden = next_hidden[range(bsz), selected_idx, :]
             last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)

-            decoder_hidden_states_one_step = []
+            decoder_hidden_states = []
             for layer in full_hidden_states:
                 layer = torch.stack(torch.split(layer.squeeze(dim=1), top_k))
                 layer = layer[range(bsz), selected_idx, :]
-                decoder_hidden_states_one_step.append(layer)
+                decoder_hidden_states.append(layer)

             # select the past_key_value
             new_key_values = []
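The `stack(split(...))` idiom above regroups the flat `B*K` candidate rows into `K`-sized groups per batch item before picking the winner; a toy version (illustrative values):

    import torch

    bsz, top_k, embed_dim = 2, 3, 4
    layer = torch.randn(bsz * top_k, 1, embed_dim)                 # flat candidate rows
    selected_idx = torch.tensor([1, 2])                            # winning candidate per batch item
    layer = torch.stack(torch.split(layer.squeeze(dim=1), top_k))  # [B, K, D]
    layer = layer[range(bsz), selected_idx, :]                     # [B, D]
    print(layer.shape)  # torch.Size([2, 4])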
@@ -2009,21 +2023,24 @@ class GenerationMixin:
             past_key_values = new_key_values

             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(bsz), selected_idx, :]

-            # contrastive_search main logic end::
-            # after running the above codes, we update following parameters: next_tokens, past_key_values,
-            # logit_for_next_step, selected_score, decoder_hidden_states_one_step
+            # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
+            if self.config.is_encoder_decoder:
+                outputs = Seq2SeqLMOutput(
+                    past_key_values=past_key_values,
+                    decoder_hidden_states=decoder_hidden_states,
+                )
+            else:
+                outputs = CausalLMOutputWithPast(
+                    past_key_values=past_key_values,
+                    hidden_states=decoder_hidden_states,
+                    attentions=model_kwargs["attention_mask"],
+                )
+            # contrastive_search main logic end

             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need

-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (selected_scores,)
-                if output_hidden_states:
-                    decoder_hidden_states += (decoder_hidden_states_one_step,)
-
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
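The rebuild above packs the already-selected tensors back into a standard `ModelOutput`, so the generic kwargs update can consume them like any other decoding loop's output. A minimal sketch with placeholder tensors (illustrative):

    import torch
    from transformers.modeling_outputs import CausalLMOutputWithPast

    dummy_hidden = tuple(torch.zeros(1, 1, 8) for _ in range(2))  # stands in for the selected states
    outputs = CausalLMOutputWithPast(hidden_states=dummy_hidden)
    # Fields left as None are dropped from the ModelOutput's keys, which is also why the
    # `"past_key_values" not in outputs` guard earlier in the file works.
    print("past_key_values" in outputs, "hidden_states" in outputs)  # False True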
@@ -2032,14 +2049,6 @@ class GenerationMixin:
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

-            if self.config.is_encoder_decoder:
-                outputs = Seq2SeqLMOutput(
-                    past_key_values=past_key_values,
-                )
-            else:
-                outputs = CausalLMOutputWithCrossAttentions(
-                    past_key_values=past_key_values, attentions=model_kwargs["attention_mask"]
-                )
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
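With the output rebuilt earlier in the loop, the duplicated rebuild here can go, leaving only the shared kwargs update. A standalone sketch of that call, assuming a v4.24-era install (the call signature matches the diff; the gpt2 checkpoint is illustrative):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("Hello", return_tensors="pt")
    outputs = model(**inputs, return_dict=True)

    model_kwargs = {"attention_mask": inputs["attention_mask"]}
    model_kwargs = model._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
    )
    # The cache lands under the key the loop checks ("past" in this era) and the
    # attention mask grows by one column for the next step.
    print(sorted(model_kwargs), model_kwargs["attention_mask"].shape)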
@@ -3884,17 +3893,18 @@ def top_k_top_p_filtering(
     return logits


-def ranking_fast(
+def _ranking_fast(
     context_hidden: torch.FloatTensor,
     next_hidden: torch.FloatTensor,
     next_top_k_probs: torch.FloatTensor,
     alpha: float,
     beam_width: int,
-) -> Tuple[torch.FloatTensor]:
+) -> torch.FloatTensor:
     """
-    context_hidden: bsz*beam x seqlen x embed_dim next_hidden: bsz*beam x 1 x embed_dim next_top_k_probs: bsz x beam
+    Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
+    in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each
+    row in the batch.
     """
     _, context_len, embed_dim = context_hidden.size()
     norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
     norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
     cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
@@ -3902,5 +3912,5 @@ def ranking_fast(
     next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
     contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
     contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
-    selected_scores, selected_idx = contrastive_score.max(dim=-1)  # [B]
-    return torch.log(selected_scores), selected_idx
+    _, selected_idx = contrastive_score.max(dim=-1)  # [B]
+    return selected_idx
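Putting the two `_ranking_fast` hunks together: each candidate's score is `(1 - alpha) * model_confidence - alpha * degeneration_penalty`, where the penalty is the candidate's maximum cosine similarity to any previous hidden state. A toy end-to-end re-implementation on random tensors (illustrative, not the library code):

    import torch

    bsz, beam_width, seqlen, embed_dim = 2, 3, 5, 8
    alpha = 0.6
    context_hidden = torch.randn(bsz * beam_width, seqlen, embed_dim)
    next_hidden = torch.randn(bsz * beam_width, 1, embed_dim)
    next_top_k_probs = torch.rand(bsz, beam_width)

    norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, S]
    degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)                         # [B*K]

    score = (1.0 - alpha) * next_top_k_probs.view(-1) - alpha * degeneration_penalty
    _, selected_idx = torch.stack(torch.split(score, beam_width)).max(dim=-1)          # [B]
    print(selected_idx)  # index of the best candidate for each batch row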