Unverified Commit b01483fa authored by Sylvain Gugger, committed by GitHub

Truncate max length if needed in all examples (#10034)

parent 45aaf5f7
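
This commit makes every example script compute an effective max_seq_length up front, capped at the tokenizer's model_max_length (and, when no value is passed, defaulting to at most 1024), then use that value in every tokenizer call instead of the raw data_args.max_seq_length. As a quick standalone check of the limit the scripts now respect (the checkpoint name here is just an example, not something the diff prescribes):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(tokenizer.model_max_length)  # 512 for BERT checkpoints

# The pattern the scripts now apply: never exceed the model's limit.
requested = 1024
max_seq_length = min(requested, tokenizer.model_max_length)
print(max_seq_length)  # 512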
@@ -303,6 +303,22 @@ def main():
         column_names = datasets["validation"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
 
+    if data_args.max_seq_length is None:
+        max_seq_length = tokenizer.model_max_length
+        if max_seq_length > 1024:
+            logger.warn(
+                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
+                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
+            )
+            max_seq_length = 1024
+    else:
+        if data_args.max_seq_length > tokenizer.model_max_length:
+            logger.warn(
+                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+            )
+        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     if data_args.line_by_line:
         # When using line_by_line, we just tokenize each nonempty line.
         padding = "max_length" if data_args.pad_to_max_length else False
@@ -314,7 +330,7 @@ def main():
                 examples["text"],
                 padding=padding,
                 truncation=True,
-                max_length=data_args.max_seq_length,
+                max_length=max_seq_length,
                 # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
                 # receives the `special_tokens_mask`.
                 return_special_tokens_mask=True,
@@ -342,22 +358,6 @@ def main():
             load_from_cache_file=not data_args.overwrite_cache,
         )
 
-        if data_args.max_seq_length is None:
-            max_seq_length = tokenizer.model_max_length
-            if max_seq_length > 1024:
-                logger.warn(
-                    f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
-                    "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
-                )
-                max_seq_length = 1024
-        else:
-            if data_args.max_seq_length > tokenizer.model_max_length:
-                logger.warn(
-                    f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
-                    f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
-                )
-            max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
-
         # Main data processing function that will concatenate all texts from our dataset and generate chunks of
         # max_seq_length.
         def group_texts(examples):
...
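
The diff is collapsed before the body of group_texts, the helper the trailing comments introduce. For readers following along, a minimal sketch of the concatenate-and-chunk approach those comments describe; it mirrors the function in these example scripts, though details may differ from any given version:

def group_texts(examples):
    # Concatenate all texts in the batch into one long sequence per key.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the small remainder at the end; it could be padded instead.
    total_length = (total_length // max_seq_length) * max_seq_length
    # Split into chunks of exactly max_seq_length tokens.
    return {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated.items()
    }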
@@ -300,6 +300,13 @@ def main():
         column_names = datasets["validation"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
 
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warn(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     if data_args.line_by_line:
         # When using line_by_line, we just tokenize each nonempty line.
         padding = "max_length" if data_args.pad_to_max_length else False
@@ -307,7 +314,7 @@ def main():
         def tokenize_function(examples):
             # Remove empty lines
             examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
-            return tokenizer(examples["text"], padding=padding, truncation=True, max_length=data_args.max_seq_length)
+            return tokenizer(examples["text"], padding=padding, truncation=True, max_length=max_seq_length)
 
         tokenized_datasets = datasets.map(
             tokenize_function,
@@ -329,13 +336,6 @@ def main():
             load_from_cache_file=not data_args.overwrite_cache,
         )
 
-        if data_args.max_seq_length > tokenizer.model_max_length:
-            logger.warn(
-                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
-                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
-            )
-        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
-
         # Main data processing function that will concatenate all texts from our dataset and generate chunks of
         # max_seq_length.
        def group_texts(examples):
...
@@ -286,6 +286,22 @@ def main():
     context_name = "sent1"
     question_header_name = "sent2"
 
+    if data_args.max_seq_length is None:
+        max_seq_length = tokenizer.model_max_length
+        if max_seq_length > 1024:
+            logger.warn(
+                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
+                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
+            )
+            max_seq_length = 1024
+    else:
+        if data_args.max_seq_length > tokenizer.model_max_length:
+            logger.warn(
+                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+            )
+        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     # Preprocessing the datasets.
     def preprocess_function(examples):
         first_sentences = [[context] * 4 for context in examples[context_name]]
@@ -303,7 +319,7 @@ def main():
             first_sentences,
             second_sentences,
             truncation=True,
-            max_length=data_args.max_seq_length,
+            max_length=max_seq_length,
             padding="max_length" if data_args.pad_to_max_length else False,
         )
         # Un-flatten
...
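
For context on the "Un-flatten" comment above: the multiple-choice script tokenizes four (context, ending) pairs per example, so it flattens the nested lists before calling the tokenizer and regroups afterwards. A minimal sketch of that round trip; the hard-coded endings are invented for illustration, while the real script reads them from the dataset's ending columns:

# Each example repeats its context 4 times, one copy per candidate ending.
first_sentences = [[context] * 4 for context in examples["sent1"]]
second_sentences = [["ending 0", "ending 1", "ending 2", "ending 3"] for _ in examples["sent1"]]

# Flatten so the tokenizer sees one flat list of sentence pairs...
first_sentences = sum(first_sentences, [])
second_sentences = sum(second_sentences, [])
tokenized = tokenizer(first_sentences, second_sentences, truncation=True, max_length=max_seq_length)

# ...then un-flatten back into groups of 4 choices per example.
unflattened = {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized.items()}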
@@ -277,6 +277,13 @@ def main():
     # Padding side determines if we do (question|context) or (context|question).
     pad_on_right = tokenizer.padding_side == "right"
 
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warn(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     # Training preprocessing
     def prepare_train_features(examples):
         # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
@@ -286,7 +293,7 @@ def main():
             examples[question_column_name if pad_on_right else context_column_name],
             examples[context_column_name if pad_on_right else question_column_name],
             truncation="only_second" if pad_on_right else "only_first",
-            max_length=data_args.max_seq_length,
+            max_length=max_seq_length,
             stride=data_args.doc_stride,
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
@@ -368,7 +375,7 @@ def main():
             examples[question_column_name if pad_on_right else context_column_name],
             examples[context_column_name if pad_on_right else question_column_name],
             truncation="only_second" if pad_on_right else "only_first",
-            max_length=data_args.max_seq_length,
+            max_length=max_seq_length,
             stride=data_args.doc_stride,
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
...
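
Capping max_seq_length is safe in the question-answering scripts because they tokenize with return_overflowing_tokens and a stride: a long context becomes several overlapping features rather than being silently cut off. A quick standalone demonstration; the checkpoint and text are arbitrary choices, not something this diff prescribes:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
question = "What does the stride control?"
context = "word " * 600  # long enough to overflow a 128-token window

features = tokenizer(
    question,
    context,
    truncation="only_second",  # truncate only the context, never the question
    max_length=128,
    stride=32,  # consecutive features overlap by 32 tokens
    return_overflowing_tokens=True,
)
print(len(features["input_ids"]))  # several overlapping features, not one truncated example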
@@ -267,6 +267,13 @@ def main():
     # Padding side determines if we do (question|context) or (context|question).
     pad_on_right = tokenizer.padding_side == "right"
 
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warn(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     # Training preprocessing
     def prepare_train_features(examples):
         # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
@@ -276,7 +283,7 @@ def main():
             examples[question_column_name if pad_on_right else context_column_name],
             examples[context_column_name if pad_on_right else question_column_name],
             truncation="only_second" if pad_on_right else "only_first",
-            max_length=data_args.max_seq_length,
+            max_length=max_seq_length,
             stride=data_args.doc_stride,
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
@@ -381,7 +388,7 @@ def main():
             examples[question_column_name if pad_on_right else context_column_name],
             examples[context_column_name if pad_on_right else question_column_name],
             truncation="only_second" if pad_on_right else "only_first",
-            max_length=data_args.max_seq_length,
+            max_length=max_seq_length,
             stride=data_args.doc_stride,
             return_overflowing_tokens=True,
             return_offsets_mapping=True,
...
@@ -334,12 +334,19 @@ def main():
     elif data_args.task_name is None and not is_regression:
         label_to_id = {v: i for i, v in enumerate(label_list)}
 
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warn(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
     def preprocess_function(examples):
         # Tokenize the texts
         args = (
             (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
         )
-        result = tokenizer(*args, padding=padding, max_length=data_args.max_seq_length, truncation=True)
+        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
 
         # Map labels to IDs (not necessary for GLUE tasks)
         if label_to_id is not None and "label" in examples:
...
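
As a side note on the context lines above: label_to_id converts string labels to integer ids inside preprocess_function. A toy illustration with invented labels; the real script builds label_list from the task or dataset:

label_list = ["entailment", "neutral", "contradiction"]  # invented for illustration
label_to_id = {v: i for i, v in enumerate(label_list)}

examples = {"label": ["neutral", "entailment"]}
if label_to_id is not None and "label" in examples:
    examples["label"] = [label_to_id[l] for l in examples["label"]]
print(examples["label"])  # [1, 0]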