Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
fc84bd52
Commit
fc84bd52
authored
Dec 25, 2019
by
patrickvonplaten
Browse files
adapt style to predefined style layout
parent
deff792b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
8 additions
and
8 deletions
+8
-8
src/transformers/modeling_ctrl.py
src/transformers/modeling_ctrl.py
+1
-1
src/transformers/modeling_gpt2.py
src/transformers/modeling_gpt2.py
+1
-1
src/transformers/modeling_transfo_xl.py
src/transformers/modeling_transfo_xl.py
+2
-2
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+2
-2
src/transformers/modeling_xlnet.py
src/transformers/modeling_xlnet.py
+2
-2
No files found.
src/transformers/modeling_ctrl.py
View file @
fc84bd52
...
@@ -492,7 +492,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
...
@@ -492,7 +492,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
def
prepare_inputs_for_generation
(
self
,
input_ids
,
**
kwargs
):
def
prepare_inputs_for_generation
(
self
,
input_ids
,
**
kwargs
):
# only last token for inputs_ids if past is defined in kwargs
# only last token for inputs_ids if past is defined in kwargs
if
'
past
'
in
kwargs
and
kwargs
[
'
past
'
]:
if
"
past
"
in
kwargs
and
kwargs
[
"
past
"
]:
input_ids
=
input_ids
[:,
-
1
].
unsqueeze
(
-
1
)
input_ids
=
input_ids
[:,
-
1
].
unsqueeze
(
-
1
)
inputs
=
{
"input_ids"
:
input_ids
}
inputs
=
{
"input_ids"
:
input_ids
}
...
...
src/transformers/modeling_gpt2.py
View file @
fc84bd52
...
@@ -561,7 +561,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
...
@@ -561,7 +561,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
def
prepare_inputs_for_generation
(
self
,
input_ids
,
**
kwargs
):
def
prepare_inputs_for_generation
(
self
,
input_ids
,
**
kwargs
):
# only last token for inputs_ids if past is defined in kwargs
# only last token for inputs_ids if past is defined in kwargs
if
'
past
'
in
kwargs
and
kwargs
[
'
past
'
]:
if
"
past
"
in
kwargs
and
kwargs
[
"
past
"
]:
input_ids
=
input_ids
[:,
-
1
].
unsqueeze
(
-
1
)
input_ids
=
input_ids
[:,
-
1
].
unsqueeze
(
-
1
)
inputs
=
{
"input_ids"
:
input_ids
}
inputs
=
{
"input_ids"
:
input_ids
}
...
...
src/transformers/modeling_transfo_xl.py
View file @
fc84bd52
...
@@ -935,7 +935,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
...
@@ -935,7 +935,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
inputs
=
{
"input_ids"
:
input_ids
}
inputs
=
{
"input_ids"
:
input_ids
}
# if past is defined in model kwargs then use it for faster decoding
# if past is defined in model kwargs then use it for faster decoding
if
'
past
'
in
model_kwargs
and
model_kwargs
[
'
past
'
]:
if
"
past
"
in
model_kwargs
and
model_kwargs
[
"
past
"
]:
inputs
[
'
mems
'
]
=
model_kwargs
[
'
past
'
]
inputs
[
"
mems
"
]
=
model_kwargs
[
"
past
"
]
return
inputs
return
inputs
src/transformers/modeling_utils.py
View file @
fc84bd52
...
@@ -540,8 +540,8 @@ class PreTrainedModel(nn.Module):
...
@@ -540,8 +540,8 @@ class PreTrainedModel(nn.Module):
return
{
"input_ids"
:
input_ids
}
return
{
"input_ids"
:
input_ids
}
def
_do_output_past
(
self
,
outputs
):
def
_do_output_past
(
self
,
outputs
):
has_output_past
=
hasattr
(
self
.
config
,
'
output_past
'
)
and
self
.
config
.
output_past
has_output_past
=
hasattr
(
self
.
config
,
"
output_past
"
)
and
self
.
config
.
output_past
has_mem_len
=
hasattr
(
self
.
config
,
'
mem_len
'
)
and
self
.
config
.
mem_len
has_mem_len
=
hasattr
(
self
.
config
,
"
mem_len
"
)
and
self
.
config
.
mem_len
if
has_output_past
and
not
has_mem_len
and
len
(
outputs
)
>
1
:
if
has_output_past
and
not
has_mem_len
and
len
(
outputs
)
>
1
:
return
True
return
True
...
...
src/transformers/modeling_xlnet.py
View file @
fc84bd52
...
@@ -1031,8 +1031,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -1031,8 +1031,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
inputs
=
{
"input_ids"
:
input_ids
,
"perm_mask"
:
perm_mask
,
"target_mapping"
:
target_mapping
}
inputs
=
{
"input_ids"
:
input_ids
,
"perm_mask"
:
perm_mask
,
"target_mapping"
:
target_mapping
}
# if past is defined in model kwargs then use it for faster decoding
# if past is defined in model kwargs then use it for faster decoding
if
'
past
'
in
model_kwargs
and
model_kwargs
[
'
past
'
]:
if
"
past
"
in
model_kwargs
and
model_kwargs
[
"
past
"
]:
inputs
[
'
mems
'
]
=
model_kwargs
[
'
past
'
]
inputs
[
"
mems
"
]
=
model_kwargs
[
"
past
"
]
return
inputs
return
inputs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment