Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
deff792b
Commit
deff792b
authored
Dec 25, 2019
by
patrickvonplaten
Browse files
add prepare inputs for transfo_xl and xlnet
parent
9398058e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
8 deletions
+23
-8
src/transformers/modeling_transfo_xl.py
src/transformers/modeling_transfo_xl.py
+9
-0
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+7
-7
src/transformers/modeling_xlnet.py
src/transformers/modeling_xlnet.py
+7
-1
No files found.
src/transformers/modeling_transfo_xl.py
View file @
deff792b
...
@@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
...
@@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
return
self
.
out_layer
return
self
.
out_layer
else
:
else
:
return
self
.
crit
.
out_layers
[
-
1
]
return
self
.
crit
.
out_layers
[
-
1
]
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
    """Assemble the per-step model inputs used during generation.

    Transfo-XL's forward() takes its cached state under the name ``mems``,
    so the generic ``past`` kwarg (when present and truthy) is remapped to
    ``mems`` to enable faster decoding.

    Args:
        input_ids: token ids for the current decoding step.
        **model_kwargs: generic generation kwargs; only ``past`` is consulted.

    Returns:
        dict with ``input_ids`` and, when a cached state exists, ``mems``.
    """
    inputs = {"input_ids": input_ids}
    # Only forward a truthy cached state; an absent or empty `past`
    # means there is nothing to reuse.
    past_state = model_kwargs.get('past')
    if past_state:
        inputs['mems'] = past_state
    return inputs
src/transformers/modeling_utils.py
View file @
deff792b
...
@@ -540,15 +540,14 @@ class PreTrainedModel(nn.Module):
...
@@ -540,15 +540,14 @@ class PreTrainedModel(nn.Module):
return
{
"input_ids"
:
input_ids
}
return
{
"input_ids"
:
input_ids
}
def
_do_output_past
(
self
,
outputs
):
def
_do_output_past
(
self
,
outputs
):
# TODO: might be better to write a self.do_output_past method for each
# individual class as is done for prepare_inputs_for_generation
has_output_past
=
hasattr
(
self
.
config
,
'output_past'
)
and
self
.
config
.
output_past
has_output_past
=
hasattr
(
self
.
config
,
'output_past'
)
and
self
.
config
.
output_past
has_multiple_outputs
=
len
(
outputs
)
>
1
has_mem_len
=
hasattr
(
self
.
config
,
'mem_len'
)
and
self
.
config
.
mem_len
has_mem_len
=
hasattr
(
self
.
config
,
'mem_len'
)
if
has_output_past
and
has_m
ultiple_outputs
and
not
has_mem_len
:
if
has_output_past
and
not
has_m
em_len
and
len
(
outputs
)
>
1
:
return
True
return
True
# TODO: Add cases for (xlnet, transfo_xl) using mem_len
elif
has_mem_len
and
self
.
config
.
mem_len
>
0
and
len
(
outputs
)
>
1
:
return
True
return
False
return
False
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -921,7 +920,8 @@ class PreTrainedModel(nn.Module):
...
@@ -921,7 +920,8 @@ class PreTrainedModel(nn.Module):
if
past
:
if
past
:
reordered_past
=
[]
reordered_past
=
[]
for
layer_past
in
past
:
for
layer_past
in
past
:
# copy the relevant beam idx past to past
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past
=
[
layer_past
[:,
i
].
unsqueeze
(
1
).
clone
().
detach
()
for
i
in
beam_idx
]
reordered_layer_past
=
[
layer_past
[:,
i
].
unsqueeze
(
1
).
clone
().
detach
()
for
i
in
beam_idx
]
reordered_layer_past
=
torch
.
cat
(
reordered_layer_past
,
dim
=
1
)
reordered_layer_past
=
torch
.
cat
(
reordered_layer_past
,
dim
=
1
)
# check that shape matches
# check that shape matches
...
...
src/transformers/modeling_xlnet.py
View file @
deff792b
...
@@ -1028,7 +1028,13 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -1028,7 +1028,13 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
)
)
target_mapping
[
0
,
0
,
-
1
]
=
1.0
target_mapping
[
0
,
0
,
-
1
]
=
1.0
return
{
"input_ids"
:
input_ids
,
"perm_mask"
:
perm_mask
,
"target_mapping"
:
target_mapping
}
inputs
=
{
"input_ids"
:
input_ids
,
"perm_mask"
:
perm_mask
,
"target_mapping"
:
target_mapping
}
# if past is defined in model kwargs then use it for faster decoding
if
'past'
in
model_kwargs
and
model_kwargs
[
'past'
]:
inputs
[
'mems'
]
=
model_kwargs
[
'past'
]
return
inputs
def
forward
(
def
forward
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment