Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d1d15d6f
Unverified
Commit d1d15d6f, authored Jul 27, 2020 by Suraj Patil; committed by GitHub, Jul 27, 2020
Browse files
[examples (seq2seq)] fix preparing decoder_input_ids for T5 (#5994)
parent
3deffc1d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 9 additions and 3 deletions
+9
-3
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+9
-3
No files found.
examples/seq2seq/finetune.py
View file @
d1d15d6f
...
@@ -14,7 +14,7 @@ import torch
...
@@ -14,7 +14,7 @@ import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, get_linear_schedule_with_warmup
from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
try:
try:
...
@@ -131,8 +131,14 @@ class SummarizationModule(BaseTransformer):
...
@@ -131,8 +131,14 @@ class SummarizationModule(BaseTransformer):
    def _step(self, batch: dict) -> Tuple:
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
        source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
        decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
        lm_labels = target_ids[:, 1:].clone()  # why clone?
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(target_ids)
            lm_labels = target_ids
        else:
            decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
            lm_labels = target_ids[:, 1:].clone()  # why clone?
        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        if self.hparams.label_smoothing == 0:
        if self.hparams.label_smoothing == 0:
...
...
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment