Unverified Commit 4ed0fa36 authored by FilipposVentirozos, committed by GitHub

Fix pytorch seq2seq qa (#19258)



* Fixed a typo in the SQuAD example command (`--answer_column answers`)

* Fixed `preprocess_validation_function` so that the labels also cover the extra features produced when a long context overflows into several truncated spans

* Rolled back trainer_seq2seq_qa.py to avoid the UnboundLocalError: local variable 'metrics' referenced before assignment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c60381e9
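For context, a minimal, self-contained sketch (toy data only, not code from this commit) of the mismatch that the `preprocess_validation_function` fix addresses: with `return_overflowing_tokens=True`, one long example can be split into several features, so the per-example labels have to be repeated once per feature via `overflow_to_sample_mapping`, otherwise the number of label rows no longer matches the number of input features.

# Toy stand-in for the tokenizer output: 2 examples, the first one overflowing into 2 features.
model_inputs = {
    "input_ids": [[101, 7, 8], [101, 9, 10], [101, 11, 12]],
    "overflow_to_sample_mapping": [0, 0, 1],
}
# Target labels are still per original example (2 entries, not 3).
labels = {"input_ids": [[55, 56], [57, 58]]}

sample_mapping = model_inputs.pop("overflow_to_sample_mapping")

# The fix: repeat each example's labels once per feature it produced.
labels_out = [labels["input_ids"][sample_index] for sample_index in sample_mapping]
model_inputs["labels"] = labels_out

assert len(model_inputs["labels"]) == len(model_inputs["input_ids"])
print(model_inputs["labels"])  # [[55, 56], [55, 56], [57, 58]]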
@@ -115,7 +115,7 @@ python run_seq2seq_qa.py \
   --dataset_name squad_v2 \
   --context_column context \
   --question_column question \
-  --answer_column answer \
+  --answer_column answers \
   --do_train \
   --do_eval \
   --per_device_train_batch_size 12 \
...
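As a side note on the command-line fix above: the column holding the gold answers in the squad_v2 dataset is named answers (plural). A quick check, assuming the `datasets` library is installed and the Hub is reachable (this snippet is not part of the commit):

from datasets import load_dataset

# Load the validation split of SQuAD v2 and inspect its columns.
squad_v2 = load_dataset("squad_v2", split="validation")
print(squad_v2.column_names)          # ['id', 'title', 'context', 'question', 'answers']
print(squad_v2[0]["answers"].keys())  # dict_keys(['text', 'answer_start'])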
@@ -484,13 +484,19 @@ def main():
             max_length=max_seq_length,
             padding=padding,
             truncation=True,
-            return_offsets_mapping=True,
             return_overflowing_tokens=True,
+            return_offsets_mapping=True,
         )
         # Tokenize targets with the `text_target` keyword argument
         labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True)
 
+        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
+        # padding in the loss.
+        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
+            labels["input_ids"] = [
+                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
+            ]
+
         # Since one example might give us several features if it has a long context, we need a map from a feature to
         # its corresponding example. This key gives us just that.
         sample_mapping = model_inputs.pop("overflow_to_sample_mapping")
@@ -498,21 +504,16 @@ def main():
         # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
         # corresponding example_id and we will store the offset mappings.
         model_inputs["example_id"] = []
+        # Augment the overflowing tokens to the labels
+        labels_out = []
 
         for i in range(len(model_inputs["input_ids"])):
             # One example can give several spans, this is the index of the example containing this span of text.
             sample_index = sample_mapping[i]
             model_inputs["example_id"].append(examples["id"][sample_index])
+            labels_out.append(labels["input_ids"][sample_index])
 
-        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
-        # padding in the loss.
-        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
-            labels["input_ids"] = [
-                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
-            ]
-
-        model_inputs["labels"] = labels["input_ids"]
+        model_inputs["labels"] = labels_out
         return model_inputs
 
     if training_args.do_train:
...
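To see the fixed validation preprocessing behave end to end, here is a rough standalone sketch; the checkpoint (t5-small), the "question: ... context: ..." formatting, and the small max lengths are illustrative assumptions rather than the script's actual arguments:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

# One question with a context long enough to overflow the 64-token window.
inputs = ["question: What is the capital of France? context: " + "Paris is the capital of France. " * 40]
targets = ["Paris"]

model_inputs = tokenizer(
    inputs,
    max_length=64,
    truncation=True,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
labels = tokenizer(text_target=targets, max_length=30, truncation=True)

# As in the diff above: one label sequence per overflowing feature.
sample_mapping = model_inputs.pop("overflow_to_sample_mapping")
model_inputs["labels"] = [labels["input_ids"][i] for i in sample_mapping]

print(len(model_inputs["input_ids"]), len(model_inputs["labels"]))  # equal feature counts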