Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
28931f81
Unverified
Commit
28931f81
authored
Jul 28, 2020
by
Sylvain Gugger
Committed by
GitHub
Jul 28, 2020
Browse files
Fix #6092 (#6093)
* Fix #6092 * Format
parent
5e97c829
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
3 deletions
+4
-3
src/transformers/data/data_collator.py
src/transformers/data/data_collator.py
+4
-3
No files found.
src/transformers/data/data_collator.py
View file @
28931f81
...
...
@@ -5,6 +5,7 @@ import torch
from
torch.nn.utils.rnn
import
pad_sequence
from
..tokenization_utils
import
PreTrainedTokenizer
from
..tokenization_utils_base
import
BatchEncoding
InputDataClass
=
NewType
(
"InputDataClass"
,
Any
)
...
...
@@ -33,7 +34,7 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
# have the same attributes.
# So we will look at the first element as a proxy for what attributes exist
# on the whole batch.
if
not
isinstance
(
features
[
0
],
dict
):
if
not
isinstance
(
features
[
0
],
(
dict
,
BatchEncoding
)
):
features
=
[
vars
(
f
)
for
f
in
features
]
first
=
features
[
0
]
...
...
@@ -78,7 +79,7 @@ class DataCollatorForLanguageModeling:
mlm_probability
:
float
=
0.15
def
__call__
(
self
,
examples
:
List
[
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]])
->
Dict
[
str
,
torch
.
Tensor
]:
if
isinstance
(
examples
[
0
],
dict
):
if
isinstance
(
examples
[
0
],
(
dict
,
BatchEncoding
)
):
examples
=
[
e
[
"input_ids"
]
for
e
in
examples
]
batch
=
self
.
_tensorize_batch
(
examples
)
if
self
.
mlm
:
...
...
@@ -151,7 +152,7 @@ class DataCollatorForPermutationLanguageModeling:
max_span_length
:
int
=
5
# maximum length of a span of masked tokens
def
__call__
(
self
,
examples
:
List
[
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]])
->
Dict
[
str
,
torch
.
Tensor
]:
if
isinstance
(
examples
[
0
],
dict
):
if
isinstance
(
examples
[
0
],
(
dict
,
BatchEncoding
)
):
examples
=
[
e
[
"input_ids"
]
for
e
in
examples
]
batch
=
self
.
_tensorize_batch
(
examples
)
inputs
,
perm_mask
,
target_mapping
,
labels
=
self
.
mask_tokens
(
batch
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment