Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
deff792b
Commit
deff792b
authored
Dec 25, 2019
by
patrickvonplaten
Browse files
add prepare inputs for transfo_xl and xlnet
parent
9398058e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
8 deletions
+23
-8
src/transformers/modeling_transfo_xl.py
src/transformers/modeling_transfo_xl.py
+9
-0
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+7
-7
src/transformers/modeling_xlnet.py
src/transformers/modeling_xlnet.py
+7
-1
No files found.
src/transformers/modeling_transfo_xl.py
View file @
deff792b
...
@@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
...
@@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
return
self
.
out_layer
return
self
.
out_layer
else
:
else
:
return
self
.
crit
.
out_layers
[
-
1
]
return
self
.
crit
.
out_layers
[
-
1
]
def
prepare_inputs_for_generation
(
self
,
input_ids
,
**
model_kwargs
):
inputs
=
{
"input_ids"
:
input_ids
}
# if past is defined in model kwargs then use it for faster decoding
if
'past'
in
model_kwargs
and
model_kwargs
[
'past'
]:
inputs
[
'mems'
]
=
model_kwargs
[
'past'
]
return
inputs
src/transformers/modeling_utils.py
View file @
deff792b
...
@@ -540,15 +540,14 @@ class PreTrainedModel(nn.Module):
...
@@ -540,15 +540,14 @@ class PreTrainedModel(nn.Module):
return
{
"input_ids"
:
input_ids
}
return
{
"input_ids"
:
input_ids
}
def
_do_output_past
(
self
,
outputs
):
def
_do_output_past
(
self
,
outputs
):
# TODO: might be better to write a self.do_output_past method for each
# individual class as is done for prepare_inputs_for_generation
has_output_past
=
hasattr
(
self
.
config
,
'output_past'
)
and
self
.
config
.
output_past
has_output_past
=
hasattr
(
self
.
config
,
'output_past'
)
and
self
.
config
.
output_past
has_multiple_outputs
=
len
(
outputs
)
>
1
has_mem_len
=
hasattr
(
self
.
config
,
'mem_len'
)
and
self
.
config
.
mem_len
has_mem_len
=
hasattr
(
self
.
config
,
'mem_len'
)
if
has_output_past
and
has_multiple_outputs
and
not
has_mem_len
:
if
has_output_past
and
not
has_mem_len
and
len
(
outputs
)
>
1
:
return
True
return
True
# TODO: Add cases for (xlnet, transfo_xl) using mem_len
elif
has_mem_len
and
self
.
config
.
mem_len
>
0
and
len
(
outputs
)
>
1
:
return
True
return
False
return
False
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -921,7 +920,8 @@ class PreTrainedModel(nn.Module):
...
@@ -921,7 +920,8 @@ class PreTrainedModel(nn.Module):
if
past
:
if
past
:
reordered_past
=
[]
reordered_past
=
[]
for
layer_past
in
past
:
for
layer_past
in
past
:
# copy the relevant beam idx past to past
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past
=
[
layer_past
[:,
i
].
unsqueeze
(
1
).
clone
().
detach
()
for
i
in
beam_idx
]
reordered_layer_past
=
[
layer_past
[:,
i
].
unsqueeze
(
1
).
clone
().
detach
()
for
i
in
beam_idx
]
reordered_layer_past
=
torch
.
cat
(
reordered_layer_past
,
dim
=
1
)
reordered_layer_past
=
torch
.
cat
(
reordered_layer_past
,
dim
=
1
)
# check that shape matches
# check that shape matches
...
...
src/transformers/modeling_xlnet.py
View file @
deff792b
...
@@ -1028,7 +1028,13 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -1028,7 +1028,13 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
)
)
target_mapping
[
0
,
0
,
-
1
]
=
1.0
target_mapping
[
0
,
0
,
-
1
]
=
1.0
return
{
"input_ids"
:
input_ids
,
"perm_mask"
:
perm_mask
,
"target_mapping"
:
target_mapping
}
inputs
=
{
"input_ids"
:
input_ids
,
"perm_mask"
:
perm_mask
,
"target_mapping"
:
target_mapping
}
# if past is defined in model kwargs then use it for faster decoding
if
'past'
in
model_kwargs
and
model_kwargs
[
'past'
]:
inputs
[
'mems'
]
=
model_kwargs
[
'past'
]
return
inputs
def
forward
(
def
forward
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment