chenpangpang / transformers · Commits

Commit 8a377c3d (unverified)
Authored Jun 18, 2020 by Sam Shleifer, committed by GitHub on Jun 18, 2020

[fix] Move _adjust_logits above postprocess to fix Marian.generate (#5126)

parent 3d3e605a
Showing 3 changed files with 8 additions and 7 deletions (+8 -7):

    src/transformers/modeling_bart.py    +1 -1
    src/transformers/modeling_marian.py  +1 -1
    src/transformers/modeling_utils.py   +6 -5
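The fix is an ordering change: the per-model logits hook (renamed from prepare_logits_for_generation to adjust_logits_during_generation) now runs on the raw logits before log_softmax, instead of on the already-normalized scores after post-processing. A toy sketch of why that matters for Marian, which masks the pad token with -inf (assumed toy values, not the library code):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])  # (batch, vocab); index 3 = pad
pad_id = 3

# New order (this commit): mask the raw logit, then normalize.
masked = logits.clone()
masked[:, pad_id] = float("-inf")
scores_new = F.log_softmax(masked, dim=-1)

# Old order: normalize first, then mask the log-probability.
scores_old = F.log_softmax(logits.clone(), dim=-1)
scores_old[:, pad_id] = float("-inf")

print(scores_new.exp().sum())  # tensor(1.) -- pad gets zero mass, the rest renormalize
print(scores_old.exp().sum())  # < 1 -- the pad token's mass is dropped, not redistributed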
src/transformers/modeling_bart.py (view file @ 8a377c3d)
@@ -993,7 +993,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
 
-    def prepare_logits_for_generation(self, logits, cur_len, max_length):
+    def adjust_logits_during_generation(self, logits, cur_len, max_length):
         if cur_len == 1:
             self._force_token_ids_generation(logits, self.config.bos_token_id)
         if cur_len == max_length - 1 and self.config.eos_token_id is not None:
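Both overrides rely on _force_token_ids_generation. A standalone approximation of that helper, inferred from its name and these call sites (an assumption, not the library source): make token_id the only viable candidate by pushing every other logit to -inf in place.

import torch

def force_token_id(logits, token_id):
    # Hypothetical stand-in for self._force_token_ids_generation(logits, token_id).
    mask = torch.ones_like(logits, dtype=torch.bool)
    mask[:, token_id] = False
    logits[mask] = float("-inf")  # mutate in place, like the hook above

logits = torch.randn(2, 5)          # (batch, vocab)
force_token_id(logits, token_id=4)  # e.g. force EOS on the final step
print(logits.argmax(dim=-1))        # tensor([4, 4])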
src/transformers/modeling_marian.py (view file @ 8a377c3d)
@@ -46,7 +46,7 @@ class MarianMTModel(BartForConditionalGeneration):
     """
 
-    def prepare_logits_for_generation(self, logits, cur_len, max_length):
+    def adjust_logits_during_generation(self, logits, cur_len, max_length):
         logits[:, self.config.pad_token_id] = float("-inf")
         if cur_len == max_length - 1 and self.config.eos_token_id is not None:
             self._force_token_ids_generation(logits, self.config.eos_token_id)
src/transformers/modeling_utils.py (view file @ 8a377c3d)
@@ -792,7 +792,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
 
-    def prepare_logits_for_generation(self, logits, **kwargs):
+    def adjust_logits_during_generation(self, logits, **kwargs):
         return logits
 
     def _use_cache(self, outputs, use_cache):
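The base class keeps the renamed hook as an identity function, so models that need no adjustment are untouched, while subclasses such as MarianMTModel override it. A minimal sketch of the pattern (simplified stand-in classes, not the library code):

import torch

class BaseModel:
    def adjust_logits_during_generation(self, logits, **kwargs):
        return logits  # default: no adjustment

class MarianLike(BaseModel):
    pad_token_id = 3  # assumed toy config value

    def adjust_logits_during_generation(self, logits, **kwargs):
        logits[:, self.pad_token_id] = float("-inf")  # never emit padding
        return logits

logits = torch.randn(1, 5)
print(MarianLike().adjust_logits_during_generation(logits)[0, 3])  # tensor(-inf)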
@@ -1396,6 +1396,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             # if model has past, then set the past variable to speed up decoding
             if self._use_cache(outputs, use_cache):
                 past = outputs[1]
 
+            if self.config.is_encoder_decoder and do_sample is False:
+                # TODO (PVP) still a bit hacky here - there might be a better solution
+                next_token_logits = self.adjust_logits_during_generation(
+                    next_token_logits, cur_len=cur_len, max_length=max_length
+                )
+
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
@@ -1413,10 +1418,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 num_beams=num_beams,
             )
 
-            if self.config.is_encoder_decoder and do_sample is False:
-                # TODO (PVP) still a bit hacky here - there might be a better solution
-                scores = self.prepare_logits_for_generation(scores, cur_len=cur_len, max_length=max_length)
-
             assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
                 scores.shape, (batch_size * num_beams, vocab_size)
             )
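Taken together, the two hunks above move the hook call from after the score post-processing to before log_softmax. Restated as a straight-line sketch (postprocess is a hypothetical identity stand-in for the repetition and banned-token penalties applied in between):

import torch.nn.functional as F

def postprocess(scores):
    return scores  # stand-in for the penalty steps in the real loop

def beam_step_old(model, next_token_logits, cur_len, max_length):
    scores = postprocess(F.log_softmax(next_token_logits, dim=-1))
    # old: the hook ran last, mutating already-normalized log-probs
    return model.prepare_logits_for_generation(scores, cur_len=cur_len, max_length=max_length)

def beam_step_new(model, next_token_logits, cur_len, max_length):
    # new: the hook runs first, on raw logits, so log_softmax renormalizes around it
    logits = model.adjust_logits_during_generation(next_token_logits, cur_len=cur_len, max_length=max_length)
    return postprocess(F.log_softmax(logits, dim=-1))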