Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d018622d
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "34a3c25a3068ab5cdbecb08ddf2866f1209fd2dd"
Unverified
Commit
d018622d
authored
Dec 15, 2020
by
Patrick von Platen
Committed by
GitHub
Dec 15, 2020
Browse files
correct mistake in order (#9134)
parent
80bdb9c3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
3 deletions
+4
-3
src/transformers/models/bart/modeling_bart.py
src/transformers/models/bart/modeling_bart.py
+4
-3
No files found.
src/transformers/models/bart/modeling_bart.py
View file @
d018622d
...
@@ -67,14 +67,15 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
...
@@ -67,14 +67,15 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).
Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).
"""
"""
prev_output_tokens
=
input_ids
.
clone
()
prev_output_tokens
=
input_ids
.
clone
()
index_of_eos
=
(
input_ids
.
ne
(
pad_token_id
).
sum
(
dim
=
1
)
-
1
).
unsqueeze
(
-
1
)
prev_output_tokens
[:,
0
]
=
input_ids
.
gather
(
1
,
index_of_eos
).
squeeze
()
prev_output_tokens
[:,
1
:]
=
input_ids
[:,
:
-
1
]
assert
pad_token_id
is
not
None
,
"self.model.config.pad_token_id has to be defined."
assert
pad_token_id
is
not
None
,
"self.model.config.pad_token_id has to be defined."
# replace possible -100 values in labels by `pad_token_id`
# replace possible -100 values in labels by `pad_token_id`
prev_output_tokens
.
masked_fill_
(
prev_output_tokens
==
-
100
,
pad_token_id
)
prev_output_tokens
.
masked_fill_
(
prev_output_tokens
==
-
100
,
pad_token_id
)
index_of_eos
=
(
input_ids
.
ne
(
pad_token_id
).
sum
(
dim
=
1
)
-
1
).
unsqueeze
(
-
1
)
prev_output_tokens
[:,
0
]
=
input_ids
.
gather
(
1
,
index_of_eos
).
squeeze
()
prev_output_tokens
[:,
1
:]
=
input_ids
[:,
:
-
1
]
return
prev_output_tokens
return
prev_output_tokens
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment