chenpangpang / transformers · Commit b01483fa (Unverified)

Truncate max length if needed in all examples (#10034)

Authored Feb 08, 2021 by Sylvain Gugger; committed by GitHub on Feb 08, 2021. Parent: 45aaf5f7.

Showing 6 changed files with 68 additions and 31 deletions (+68 −31).
examples/language-modeling/run_mlm.py  +17 −17
examples/language-modeling/run_plm.py  +8 −8
examples/multiple-choice/run_swag.py  +17 −1
examples/question-answering/run_qa.py  +9 −2
examples/question-answering/run_qa_beam_search.py  +9 −2
examples/text-classification/run_glue.py  +8 −1
examples/language-modeling/run_mlm.py
```diff
@@ -303,6 +303,22 @@ def main():
         column_names = datasets["validation"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
 
+    if data_args.max_seq_length is None:
+        max_seq_length = tokenizer.model_max_length
+        if max_seq_length > 1024:
+            logger.warn(
+                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
+                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
+            )
+            max_seq_length = 1024
+    else:
+        if data_args.max_seq_length > tokenizer.model_max_length:
+            logger.warn(
+                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+            )
+        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     if data_args.line_by_line:
         # When using line_by_line, we just tokenize each nonempty line.
         padding = "max_length" if data_args.pad_to_max_length else False
@@ -314,7 +330,7 @@ def main():
                 examples["text"],
                 padding=padding,
                 truncation=True,
-                max_length=data_args.max_seq_length,
+                max_length=max_seq_length,
                 # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
                 # receives the `special_tokens_mask`.
                 return_special_tokens_mask=True,
@@ -342,22 +358,6 @@ def main():
             load_from_cache_file=not data_args.overwrite_cache,
         )
 
-        if data_args.max_seq_length is None:
-            max_seq_length = tokenizer.model_max_length
-            if max_seq_length > 1024:
-                logger.warn(
-                    f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
-                    "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
-                )
-                max_seq_length = 1024
-        else:
-            if data_args.max_seq_length > tokenizer.model_max_length:
-                logger.warn(
-                    f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
-                    f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
-                )
-            max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
-
         # Main data processing function that will concatenate all texts from our dataset and generate chunks of
         # max_seq_length.
         def group_texts(examples):
```
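The net effect of this commit is that the clamping block now runs once, before any preprocessing, and every tokenizer call uses the resolved `max_seq_length` rather than the raw `data_args.max_seq_length`. For reference, a minimal standalone sketch of the same decision logic (the function name and the use of `warnings` in place of the script's `logger` are ours, for illustration):

```python
import warnings

def resolve_max_seq_length(requested, model_max_length, cap=1024):
    """Mirrors the block added above: pick a safe max_seq_length.

    `requested` plays the role of data_args.max_seq_length and
    `model_max_length` the role of tokenizer.model_max_length.
    """
    if requested is None:
        # No user value: fall back to the tokenizer's limit, but cap it,
        # since some tokenizers report a huge placeholder model_max_length.
        if model_max_length > cap:
            warnings.warn(
                f"model_max_length ({model_max_length}) is very large; picking {cap} instead."
            )
            return cap
        return model_max_length
    if requested > model_max_length:
        warnings.warn(
            f"max_seq_length ({requested}) exceeds the model limit ({model_max_length})."
        )
    # Never exceed what the model can actually handle.
    return min(requested, model_max_length)

# GPT-2's tokenizer, for instance, reports model_max_length == 1024:
assert resolve_max_seq_length(None, 1024) == 1024
assert resolve_max_seq_length(2048, 1024) == 1024  # clamped, with a warning
assert resolve_max_seq_length(128, 1024) == 128
```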
examples/language-modeling/run_plm.py
```diff
@@ -300,6 +300,13 @@ def main():
         column_names = datasets["validation"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
 
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warn(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     if data_args.line_by_line:
         # When using line_by_line, we just tokenize each nonempty line.
         padding = "max_length" if data_args.pad_to_max_length else False
@@ -307,7 +314,7 @@ def main():
         def tokenize_function(examples):
             # Remove empty lines
             examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
-            return tokenizer(examples["text"], padding=padding, truncation=True, max_length=data_args.max_seq_length)
+            return tokenizer(examples["text"], padding=padding, truncation=True, max_length=max_seq_length)
 
         tokenized_datasets = datasets.map(
             tokenize_function,
@@ -329,13 +336,6 @@ def main():
             load_from_cache_file=not data_args.overwrite_cache,
         )
 
-        if data_args.max_seq_length > tokenizer.model_max_length:
-            logger.warn(
-                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
-                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
-            )
-        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
-
         # Main data processing function that will concatenate all texts from our dataset and generate chunks of
         # max_seq_length.
         def group_texts(examples):
```
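The `tokenize_function` shown in context filters out empty and whitespace-only lines before tokenizing. On a toy batch in the same `{"text": [...]}` shape the `datasets.map` function receives (the example data is ours), the filter behaves like this:

```python
# Hypothetical toy batch, mirroring what tokenize_function receives.
examples = {"text": ["A first sentence.", "", "   ", "Another line."]}

# Same filter as in tokenize_function: drop empty and whitespace-only lines.
examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]

print(examples["text"])  # ['A first sentence.', 'Another line.']
```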
examples/multiple-choice/run_swag.py
```diff
@@ -286,6 +286,22 @@ def main():
     context_name = "sent1"
     question_header_name = "sent2"
 
+    if data_args.max_seq_length is None:
+        max_seq_length = tokenizer.model_max_length
+        if max_seq_length > 1024:
+            logger.warn(
+                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
+                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
+            )
+            max_seq_length = 1024
+    else:
+        if data_args.max_seq_length > tokenizer.model_max_length:
+            logger.warn(
+                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+            )
+        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     # Preprocessing the datasets.
     def preprocess_function(examples):
         first_sentences = [[context] * 4 for context in examples[context_name]]
@@ -303,7 +319,7 @@ def main():
             first_sentences,
             second_sentences,
             truncation=True,
-            max_length=data_args.max_seq_length,
+            max_length=max_seq_length,
             padding="max_length" if data_args.pad_to_max_length else False,
         )
         # Un-flatten
```
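For orientation: run_swag.py builds four (context, ending) pairs per example, flattens them so the tokenizer sees one flat list, and regroups afterwards (the `# Un-flatten` step above). A simplified pure-Python illustration of that reshaping, with toy SWAG-style data of our own:

```python
# One context with four candidate endings, in SWAG's column layout.
examples = {
    "sent1": ["A man is sitting at a piano."],
    "ending0": ["He plays a song."], "ending1": ["He eats lunch."],
    "ending2": ["He leaves the room."], "ending3": ["He falls asleep."],
}

# As in preprocess_function: repeat each context once per choice...
first_sentences = [[context] * 4 for context in examples["sent1"]]
# ...then flatten so the tokenizer can process a single list of pairs.
first_sentences = sum(first_sentences, [])
second_sentences = [examples[f"ending{i}"][0] for i in range(4)]

pairs = list(zip(first_sentences, second_sentences))
# Un-flatten: regroup into one list of 4 pairs per original example.
grouped = [pairs[i : i + 4] for i in range(0, len(pairs), 4)]
assert len(grouped) == 1 and len(grouped[0]) == 4
```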
examples/question-answering/run_qa.py
```diff
@@ -277,6 +277,13 @@ def main():
     # Padding side determines if we do (question|context) or (context|question).
     pad_on_right = tokenizer.padding_side == "right"
 
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warn(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     # Training preprocessing
     def prepare_train_features(examples):
         # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
@@ -286,7 +293,7 @@ def main():
             examples[question_column_name if pad_on_right else context_column_name],
             examples[context_column_name if pad_on_right else question_column_name],
             truncation="only_second" if pad_on_right else "only_first",
-            max_length=data_args.max_seq_length,
+            max_length=max_seq_length,
             stride=data_args.doc_stride,
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
@@ -368,7 +375,7 @@ def main():
             examples[question_column_name if pad_on_right else context_column_name],
             examples[context_column_name if pad_on_right else question_column_name],
             truncation="only_second" if pad_on_right else "only_first",
-            max_length=data_args.max_seq_length,
+            max_length=max_seq_length,
             stride=data_args.doc_stride,
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
```
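In the QA scripts the resolved `max_seq_length` also bounds the overlapping windows produced by `return_overflowing_tokens=True` and `stride=data_args.doc_stride`, which is why clamping it matters here. A rough, self-contained sketch of that windowing (simplified: the real tokenizer also accounts for question and special tokens inside each window, and the function name is ours):

```python
def split_with_stride(tokens, max_length, stride):
    """Split a long sequence into overlapping windows, so an answer that
    straddles a window boundary still appears whole in some window."""
    windows = []
    start = 0
    while True:
        windows.append(tokens[start : start + max_length])
        if start + max_length >= len(tokens):
            break
        start += max_length - stride  # consecutive windows overlap by `stride` tokens
    return windows

context = list(range(10))  # stand-in for 10 context tokens
print(split_with_stride(context, max_length=4, stride=2))
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
```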
examples/question-answering/run_qa_beam_search.py
```diff
@@ -267,6 +267,13 @@ def main():
     # Padding side determines if we do (question|context) or (context|question).
     pad_on_right = tokenizer.padding_side == "right"
 
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warn(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     # Training preprocessing
     def prepare_train_features(examples):
         # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
@@ -276,7 +283,7 @@ def main():
             examples[question_column_name if pad_on_right else context_column_name],
             examples[context_column_name if pad_on_right else question_column_name],
             truncation="only_second" if pad_on_right else "only_first",
-            max_length=data_args.max_seq_length,
+            max_length=max_seq_length,
             stride=data_args.doc_stride,
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
@@ -381,7 +388,7 @@ def main():
             examples[question_column_name if pad_on_right else context_column_name],
             examples[context_column_name if pad_on_right else question_column_name],
             truncation="only_second" if pad_on_right else "only_first",
-            max_length=data_args.max_seq_length,
+            max_length=max_seq_length,
             stride=data_args.doc_stride,
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
```
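The `pad_on_right` switch in the context lines decides both the pair order and which member may be truncated: with a right-padding tokenizer the pair is (question, context) and only the context (`"only_second"`) is truncated; with a left-padding tokenizer such as XLNet's (the model this beam-search script targets) the order flips and `"only_first"` is truncated. A toy illustration with our own values:

```python
question_column_name, context_column_name = "question", "context"
examples = {"question": ["Who wrote it?"], "context": ["A very long passage..."]}

for padding_side in ("right", "left"):
    pad_on_right = padding_side == "right"
    first = examples[question_column_name if pad_on_right else context_column_name]
    second = examples[context_column_name if pad_on_right else question_column_name]
    truncation = "only_second" if pad_on_right else "only_first"
    # right -> ('Who wrote it?', 'A very long passage...') only_second
    # left  -> ('A very long passage...', 'Who wrote it?') only_first
    print(padding_side, (first[0], second[0]), truncation)
```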
examples/text-classification/run_glue.py
```diff
@@ -334,12 +334,19 @@ def main():
     elif data_args.task_name is None and not is_regression:
         label_to_id = {v: i for i, v in enumerate(label_list)}
 
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warn(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     def preprocess_function(examples):
         # Tokenize the texts
         args = (
             (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
         )
-        result = tokenizer(*args, padding=padding, max_length=data_args.max_seq_length, truncation=True)
+        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
 
         # Map labels to IDs (not necessary for GLUE tasks)
         if label_to_id is not None and "label" in examples:
```
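The `args` tuple in `preprocess_function` lets one tokenizer call serve both single-sentence and sentence-pair GLUE tasks: one positional argument or two, unpacked via `*args`. A toy illustration with our own data and a hypothetical helper:

```python
examples = {"sentence1": ["It rains."], "sentence2": ["It pours."]}

def build_args(sentence1_key, sentence2_key):
    # Same expression as in preprocess_function above.
    return (
        (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
    )

print(build_args("sentence1", None))         # (['It rains.'],)
print(build_args("sentence1", "sentence2"))  # (['It rains.'], ['It pours.'])
# The script then calls:
#   tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
```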