Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7e0c5c73
Commit
7e0c5c73
authored
Dec 23, 2019
by
patrickvonplaten
Browse files
changed do_output_past function to check for self.config.output_past instead of self.output_past
parent
eeaa402c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+5
-5
No files found.
src/transformers/modeling_utils.py
View file @
7e0c5c73
...
@@ -539,10 +539,10 @@ class PreTrainedModel(nn.Module):
...
@@ -539,10 +539,10 @@ class PreTrainedModel(nn.Module):
def prepare_inputs_for_generation(self, input_ids, **kwargs):
    """Build the default keyword-argument dict fed to the model at each
    generation step.

    The base implementation simply forwards the token ids; subclasses
    override this to add model-specific inputs (e.g. cached past states).

    Args:
        input_ids: the current sequence of input token ids.
        **kwargs: ignored here; accepted so overriding subclasses can
            take extra arguments without changing the call site.

    Returns:
        dict: the model inputs, containing only ``"input_ids"``.
    """
    model_inputs = {"input_ids": input_ids}
    return model_inputs
def
_
has
_past
(
self
,
outputs
):
def
_
do_output
_past
(
self
,
outputs
):
# TODO: might be better to write a self.
has
_past method for each individual class as is done for
# TODO: might be better to write a self.
do_output
_past method for each individual class as is done for
# prepare_inputs_for_generation
# prepare_inputs_for_generation
if
hasattr
(
self
,
'output_past'
)
and
self
.
output_past
and
len
(
outputs
)
>
1
:
if
hasattr
(
self
.
config
,
'output_past'
)
and
self
.
config
.
output_past
and
len
(
outputs
)
>
1
and
not
hasattr
(
self
,
'mem_len'
)
:
return
True
return
True
# TODO: Add cases for (xlnet, transfo_xl) using mem_len
# TODO: Add cases for (xlnet, transfo_xl) using mem_len
return
False
return
False
...
@@ -732,7 +732,7 @@ class PreTrainedModel(nn.Module):
...
@@ -732,7 +732,7 @@ class PreTrainedModel(nn.Module):
next_token_logits
=
outputs
[
0
][:,
-
1
,
:]
next_token_logits
=
outputs
[
0
][:,
-
1
,
:]
# if model has past, then set the past variable to speed up decoding
# if model has past, then set the past variable to speed up decoding
if
self
.
_
has
_past
(
outputs
):
if
self
.
_
do_output
_past
(
outputs
):
past
=
outputs
[
1
]
past
=
outputs
[
1
]
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
...
@@ -819,7 +819,7 @@ class PreTrainedModel(nn.Module):
...
@@ -819,7 +819,7 @@ class PreTrainedModel(nn.Module):
scores
=
outputs
[
0
][:,
-
1
,
:]
# (batch_size * num_beams, vocab_size)
scores
=
outputs
[
0
][:,
-
1
,
:]
# (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
# if model has past, then set the past variable to speed up decoding
if
self
.
_
has
_past
(
outputs
):
if
self
.
_
do_output
_past
(
outputs
):
past
=
outputs
[
1
]
past
=
outputs
[
1
]
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment