Commit 9f4acd05 (unverified)
Authored Sep 14, 2022 by Ekagra Ranjan, committed via GitHub on Sep 14, 2022
Generate: add missing comments after refactoring of generate() (#18981)
parent 59407bbe

Changes: 1 changed file with 9 additions and 2 deletions (+9 / -2)

src/transformers/generation_utils.py
@@ -2240,6 +2240,8 @@ class GenerationMixin:
                 model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
             )
 
+        # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
+        # of the first beam are considered to avoid sampling the exact same tokens across all beams.
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
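As an aside for readers of this diff, here is a minimal standalone sketch of the initialisation the new comment documents. The sizes are made-up example values, not taken from the commit; only the three beam_scores lines mirror the code above.

import torch

batch_size, num_beams = 2, 3  # example values, not from the commit

# First beam of each batch item starts at 0; all others get -1e9 (~ -inf in
# log-space), so the first selection step only keeps tokens from beam 0.
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))
print(beam_scores)  # [0, -1e9, -1e9, 0, -1e9, -1e9]

Without the -1e9 fill, all num_beams rows hold identical scores at step one and every beam would pick the same tokens, which is exactly the failure mode the new comment warns about.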
@@ -2303,6 +2305,7 @@ class GenerationMixin:
             vocab_size = next_token_scores.shape[-1]
             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
+            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
             next_token_scores, next_tokens = torch.topk(
                 next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
             )
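The same comment is added at three call sites in this commit (here and in the hunks at lines 2954 and 3307 below), so one sketch covers all of them. This is a toy illustration with made-up sizes; the index-splitting at the end paraphrases what the surrounding beam-search loop does and is not part of this hunk.

import torch

batch_size, num_beams, vocab_size = 1, 2, 5  # toy sizes for illustration

# Per-beam scores, flattened so top-k can compare candidates across beams.
next_token_scores = torch.randn(batch_size * num_beams, vocab_size)
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

# Keep 2 * num_beams candidates: some may be EOS tokens (closing a finished
# hypothesis), so the surplus guarantees num_beams open continuations survive.
next_token_scores, next_tokens = torch.topk(
    next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
)

# Each flat index encodes (beam, token); split them back apart.
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size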
@@ -2873,9 +2876,9 @@ class GenerationMixin:
                 model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
             )
 
-        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
-        # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
+        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
         # the same group don't produce same tokens everytime.
+        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
         beam_scores[:, ::num_sub_beams] = 0
         beam_scores = beam_scores.view((batch_size * num_beams,))
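Note that this hunk does two things: it moves the group-beam-search comment above the assignment it describes, and it fixes the sign typo in it ("1e-9" becomes "-1e9"). A minimal sketch of the initialisation itself, with example sizes not taken from the commit:

import torch

batch_size, num_beams, num_beam_groups = 1, 6, 3  # example values
num_sub_beams = num_beams // num_beam_groups      # 2 beams per group

beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float)
# The first beam of every group (stride num_sub_beams) starts at 0 and the
# rest at -1e9, so each group begins from a single live beam.
beam_scores[:, ::num_sub_beams] = 0
print(beam_scores)  # [[0, -1e9, 0, -1e9, 0, -1e9]]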
@@ -2951,6 +2954,7 @@ class GenerationMixin:
             # reshape for beam search
             next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
 
+            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
             next_token_scores, next_tokens = torch.topk(
                 next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
             )
@@ -3235,6 +3239,8 @@ class GenerationMixin:
                 f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
             )
 
+        # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
+        # of the first beam are considered to avoid sampling the exact same tokens across all beams.
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
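The f-string guarded by this hunk enforces the shape contract that beam search receives one row of input_ids per (batch item, beam) pair. A sketch of that contract under example values; repeat_interleave is just one way to picture the expansion, not necessarily how generate() performs it internally.

import torch

batch_size, num_beams, cur_len = 2, 3, 4  # example values
input_ids = torch.zeros((batch_size, cur_len), dtype=torch.long)

# Expand the batch so every beam has its own row before the search loop runs.
input_ids = input_ids.repeat_interleave(num_beams, dim=0)

batch_beam_size = input_ids.shape[0]
assert batch_beam_size == num_beams * batch_size  # the check in the hunk above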
@@ -3301,6 +3307,7 @@ class GenerationMixin:
             vocab_size = next_token_scores.shape[-1]
             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
+            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
             next_token_scores, next_tokens = torch.topk(
                 next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
             )