chenpangpang / transformers · Commits

Commit 492bea9a (unverified)
Authored Dec 27, 2019 by Thomas Wolf; committed by GitHub on Dec 27, 2019
Parents: e213900f fc84bd52

Merge pull request #2292 from patrickvonplaten/add_cached_past_for_language_generation

Add cached past for language generation
Showing 5 changed files with 73 additions and 13 deletions (+73, -13):

    src/transformers/modeling_ctrl.py        +9   -0
    src/transformers/modeling_gpt2.py        +9   -0
    src/transformers/modeling_transfo_xl.py  +9   -0
    src/transformers/modeling_utils.py       +39  -12
    src/transformers/modeling_xlnet.py       +7   -1
src/transformers/modeling_ctrl.py

@@ -490,6 +490,15 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        # only last token for inputs_ids if past is defined in kwargs
+        if "past" in kwargs and kwargs["past"]:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        inputs = {"input_ids": input_ids}
+        inputs.update(kwargs)
+        return inputs
+
     def forward(
         self,
         input_ids=None,
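The slice `input_ids[:, -1].unsqueeze(-1)` in the new hook keeps the batch dimension while reducing the sequence dimension to the single most recent token. A minimal illustrative PyTorch sketch of that slicing (token ids are made up for the example):

    import torch

    # made-up token ids: batch of 2 sequences, 5 tokens generated so far
    input_ids = torch.tensor([[10, 11, 12, 13, 14],
                              [20, 21, 22, 23, 24]])

    # keep only the most recent token per sequence, as the new hook does once `past` is set
    last_tokens = input_ids[:, -1].unsqueeze(-1)

    print(last_tokens.shape)  # torch.Size([2, 1])
    print(last_tokens)        # tensor([[14], [24]])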
src/transformers/modeling_gpt2.py

@@ -559,6 +559,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        # only last token for inputs_ids if past is defined in kwargs
+        if "past" in kwargs and kwargs["past"]:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        inputs = {"input_ids": input_ids}
+        inputs.update(kwargs)
+        return inputs
+
     def forward(
         self,
         input_ids=None,
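To make the behaviour of the GPT-2/CTRL hook concrete, here is a standalone restatement of the same logic outside the class, together with the dictionaries it produces with and without a cached `past` (the placeholder cache tensor shape is invented for the example):

    import torch

    def prepare_inputs_for_generation(input_ids, **kwargs):
        # standalone restatement of the method added to GPT2LMHeadModel / CTRLLMHeadModel above
        if "past" in kwargs and kwargs["past"]:
            input_ids = input_ids[:, -1].unsqueeze(-1)
        inputs = {"input_ids": input_ids}
        inputs.update(kwargs)
        return inputs

    input_ids = torch.tensor([[5, 6, 7]])

    # first step: no cache yet, the whole prompt is forwarded
    first = prepare_inputs_for_generation(input_ids, past=None)
    print(first["input_ids"].shape)      # torch.Size([1, 3])

    # later steps: a cache is present, so only the last token is forwarded
    dummy_past = (torch.zeros(2, 1, 1, 3, 4),)  # placeholder for real per-layer key/value tensors
    later = prepare_inputs_for_generation(input_ids, past=dummy_past)
    print(later["input_ids"].shape)      # torch.Size([1, 1])
    print(later["past"] is dummy_past)   # True -- the cache rides along to forward()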
src/transformers/modeling_transfo_xl.py

@@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
             return self.out_layer
         else:
             return self.crit.out_layers[-1]
+
+    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
+        inputs = {"input_ids": input_ids}
+
+        # if past is defined in model kwargs then use it for faster decoding
+        if "past" in model_kwargs and model_kwargs["past"]:
+            inputs["mems"] = model_kwargs["past"]
+
+        return inputs
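Note the contrast with the GPT-2/CTRL hook above: Transformer-XL (and XLNet below) keep the full `input_ids` and hand the generic `past` cache to the model under the `mems` keyword. A minimal standalone sketch of that mapping, mirroring the method added in this diff (shapes and values are illustrative only):

    import torch

    def prepare_inputs_for_generation_transfo_xl(input_ids, **model_kwargs):
        # mirrors TransfoXLLMHeadModel.prepare_inputs_for_generation from the diff:
        # input_ids stays untouched and the generic "past" cache is renamed to "mems"
        inputs = {"input_ids": input_ids}
        if "past" in model_kwargs and model_kwargs["past"]:
            inputs["mems"] = model_kwargs["past"]
        return inputs

    input_ids = torch.tensor([[5, 6, 7]])
    dummy_mems = [torch.zeros(3, 1, 8)]  # placeholder memory tensors

    inputs = prepare_inputs_for_generation_transfo_xl(input_ids, past=dummy_mems)
    print(sorted(inputs.keys()))       # ['input_ids', 'mems']
    print(inputs["input_ids"].shape)   # torch.Size([1, 3]) -- not truncated to the last token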
src/transformers/modeling_utils.py

@@ -539,6 +539,17 @@ class PreTrainedModel(nn.Module):
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
 
+    def _do_output_past(self, outputs):
+        has_output_past = hasattr(self.config, "output_past") and self.config.output_past
+        has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
+
+        if has_output_past and not has_mem_len and len(outputs) > 1:
+            return True
+        elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
+            return True
+
+        return False
+
     @torch.no_grad()
     def generate(
         self,
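`_do_output_past` decides, purely from the model config and the length of the `forward()` output tuple, whether the second output can be reused as a cache. A hedged, standalone illustration with stand-in configs (the attribute names are the ones checked in the diff; the concrete values are invented for the example):

    from types import SimpleNamespace

    def do_output_past(config, outputs):
        # standalone restatement of PreTrainedModel._do_output_past above
        has_output_past = hasattr(config, "output_past") and config.output_past
        has_mem_len = hasattr(config, "mem_len") and config.mem_len
        if has_output_past and not has_mem_len and len(outputs) > 1:
            return True
        elif has_mem_len and config.mem_len > 0 and len(outputs) > 1:
            return True
        return False

    fake_outputs = ("logits", "cache")  # stands in for a (logits, past/mems, ...) tuple

    print(do_output_past(SimpleNamespace(output_past=True), fake_outputs))   # True  (GPT-2/CTRL style)
    print(do_output_past(SimpleNamespace(mem_len=512), fake_outputs))        # True  (Transformer-XL/XLNet style)
    print(do_output_past(SimpleNamespace(mem_len=0), fake_outputs))          # False (memory disabled)
    print(do_output_past(SimpleNamespace(output_past=True), ("logits",)))    # False (nothing cached to reuse)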
@@ -757,14 +768,17 @@ class PreTrainedModel(nn.Module):
         # current position / max lengths / length of generated sentences / unfinished sentences
         unfinished_sents = input_ids.new(batch_size).fill_(1)
 
-        # TODO: add cached compute states
-        pasts = None
+        past = None
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
             outputs = self(**model_inputs)
             next_token_logits = outputs[0][:, -1, :]
 
+            # if model has past, then set the past variable to speed up decoding
+            if self._do_output_past(outputs):
+                past = outputs[1]
+
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
                 for i in range(batch_size):
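Once `past` is populated, each iteration of this loop runs the model over a single token instead of the whole prefix. A hedged sketch of the same pattern written directly against GPT-2, assuming the transformers 2.x API current at this commit (`GPT2LMHeadModel` accepting a `past=` argument and returning the cached key/values as its second output); plain greedy decoding, for brevity:

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("The cached past makes decoding")])
    generated = input_ids
    past = None

    with torch.no_grad():
        for _ in range(20):
            # after the first step only the newest token is fed; attention keys/values
            # for all earlier tokens come from `past`
            outputs = model(input_ids, past=past)
            next_token_logits = outputs[0][:, -1, :]
            past = outputs[1]
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)
            input_ids = next_token  # shape (batch_size, 1) from here on

    print(tokenizer.decode(generated[0].tolist()))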
@@ -838,15 +852,19 @@ class PreTrainedModel(nn.Module):
         beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)
 
         # cache compute states
-        pasts = None  # self.prepare_pasts()
+        past = None
 
         # done sentences
         done = [False for _ in range(batch_size)]
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
-            scores = self(**model_inputs)[0]  # (batch_size * num_beams, cur_len, vocab_size)
-            scores = scores[:, -1, :]  # (batch_size * num_beams, vocab_size)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
+            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
+            scores = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
+
+            # if model has past, then set the past variable to speed up decoding
+            if self._do_output_past(outputs):
+                past = outputs[1]
 
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
@@ -935,13 +953,22 @@ class PreTrainedModel(nn.Module):
             beam_words = input_ids.new([x[1] for x in next_batch_beam])
             beam_idx = input_ids.new([x[2] for x in next_batch_beam])
 
-            # re-order batch
+            # re-order batch and internal states
             input_ids = input_ids[beam_idx, :]
             input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
 
-            # TODO: Activate cache
-            # for k in cache.keys():
-            #     if k != 'slen':
-            #         cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
+            # re-order internal states
+            if past:
+                reordered_past = []
+                for layer_past in past:
+                    # get the correct batch idx from layer past batch dim
+                    # batch dim of `past` and `mems` is at 2nd position
+                    reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
+                    reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
+                    # check that shape matches
+                    assert reordered_layer_past.shape == layer_past.shape
+                    reordered_past.append(reordered_layer_past)
+                past = tuple(reordered_past)
 
             # update current length
             cur_len = cur_len + 1
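The reordering above permutes every cached layer along its batch dimension (position 1 for both `past` and `mems`) so each surviving beam keeps the key/value history of the hypothesis it extends. A small self-contained PyTorch sketch of the same operation on a dummy cache tensor (all shapes are illustrative):

    import torch

    # dummy cached layer: (2 for key/value, batch_size * num_beams, heads, seq_len, head_dim)
    layer_past = torch.randn(2, 4, 3, 5, 8)

    # beam_idx says which previous hypothesis each new beam slot continues from
    beam_idx = torch.tensor([0, 0, 2, 3])

    # reorder exactly as in the diff: slice per beam, then concatenate back on the batch dim
    reordered = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
    reordered_layer_past = torch.cat(reordered, dim=1)

    assert reordered_layer_past.shape == layer_past.shape
    # beam slot 1 now carries a copy of slot 0's cached keys/values
    assert torch.equal(reordered_layer_past[:, 1], layer_past[:, 0])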
src/transformers/modeling_xlnet.py

@@ -1028,7 +1028,13 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         )
         target_mapping[0, 0, -1] = 1.0
 
-        return {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+        inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+
+        # if past is defined in model kwargs then use it for faster decoding
+        if "past" in model_kwargs and model_kwargs["past"]:
+            inputs["mems"] = model_kwargs["past"]
+
+        return inputs
 
     def forward(
         self,