Unverified Commit b01483fa authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Truncate max length if needed in all examples (#10034)

parent 45aaf5f7
......@@ -303,6 +303,22 @@ def main():
column_names = datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
if data_args.max_seq_length is None:
max_seq_length = tokenizer.model_max_length
if max_seq_length > 1024:
logger.warn(
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
"Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
)
max_seq_length = 1024
else:
if data_args.max_seq_length > tokenizer.model_max_length:
logger.warn(
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
if data_args.line_by_line:
# When using line_by_line, we just tokenize each nonempty line.
padding = "max_length" if data_args.pad_to_max_length else False
......@@ -314,7 +330,7 @@ def main():
examples["text"],
padding=padding,
truncation=True,
max_length=data_args.max_seq_length,
max_length=max_seq_length,
# We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
# receives the `special_tokens_mask`.
return_special_tokens_mask=True,
......@@ -342,22 +358,6 @@ def main():
load_from_cache_file=not data_args.overwrite_cache,
)
if data_args.max_seq_length is None:
max_seq_length = tokenizer.model_max_length
if max_seq_length > 1024:
logger.warn(
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
"Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
)
max_seq_length = 1024
else:
if data_args.max_seq_length > tokenizer.model_max_length:
logger.warn(
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
......
......@@ -300,6 +300,13 @@ def main():
column_names = datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
if data_args.max_seq_length > tokenizer.model_max_length:
logger.warn(
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
if data_args.line_by_line:
# When using line_by_line, we just tokenize each nonempty line.
padding = "max_length" if data_args.pad_to_max_length else False
......@@ -307,7 +314,7 @@ def main():
def tokenize_function(examples):
# Remove empty lines
examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
return tokenizer(examples["text"], padding=padding, truncation=True, max_length=data_args.max_seq_length)
return tokenizer(examples["text"], padding=padding, truncation=True, max_length=max_seq_length)
tokenized_datasets = datasets.map(
tokenize_function,
......@@ -329,13 +336,6 @@ def main():
load_from_cache_file=not data_args.overwrite_cache,
)
if data_args.max_seq_length > tokenizer.model_max_length:
logger.warn(
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
......
......@@ -286,6 +286,22 @@ def main():
context_name = "sent1"
question_header_name = "sent2"
if data_args.max_seq_length is None:
max_seq_length = tokenizer.model_max_length
if max_seq_length > 1024:
logger.warn(
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
"Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
)
max_seq_length = 1024
else:
if data_args.max_seq_length > tokenizer.model_max_length:
logger.warn(
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
# Preprocessing the datasets.
def preprocess_function(examples):
first_sentences = [[context] * 4 for context in examples[context_name]]
......@@ -303,7 +319,7 @@ def main():
first_sentences,
second_sentences,
truncation=True,
max_length=data_args.max_seq_length,
max_length=max_seq_length,
padding="max_length" if data_args.pad_to_max_length else False,
)
# Un-flatten
......
......@@ -277,6 +277,13 @@ def main():
# Padding side determines if we do (question|context) or (context|question).
pad_on_right = tokenizer.padding_side == "right"
if data_args.max_seq_length > tokenizer.model_max_length:
logger.warn(
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
# Training preprocessing
def prepare_train_features(examples):
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
......@@ -286,7 +293,7 @@ def main():
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=data_args.max_seq_length,
max_length=max_seq_length,
stride=data_args.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
......@@ -368,7 +375,7 @@ def main():
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=data_args.max_seq_length,
max_length=max_seq_length,
stride=data_args.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
......
......@@ -267,6 +267,13 @@ def main():
# Padding side determines if we do (question|context) or (context|question).
pad_on_right = tokenizer.padding_side == "right"
if data_args.max_seq_length > tokenizer.model_max_length:
logger.warn(
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
# Training preprocessing
def prepare_train_features(examples):
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
......@@ -276,7 +283,7 @@ def main():
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=data_args.max_seq_length,
max_length=max_seq_length,
stride=data_args.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
......@@ -381,7 +388,7 @@ def main():
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=data_args.max_seq_length,
max_length=max_seq_length,
stride=data_args.doc_stride,
return_overflowing_tokens=True,
return_offsets_mapping=True,
......
......@@ -334,12 +334,19 @@ def main():
elif data_args.task_name is None and not is_regression:
label_to_id = {v: i for i, v in enumerate(label_list)}
if data_args.max_seq_length > tokenizer.model_max_length:
logger.warn(
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
def preprocess_function(examples):
# Tokenize the texts
args = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*args, padding=padding, max_length=data_args.max_seq_length, truncation=True)
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
# Map labels to IDs (not necessary for GLUE tasks)
if label_to_id is not None and "label" in examples:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment