Commit d4614729 authored by Lysandre

return for SQuAD [BLACKED]

parent f24a228a
...@@ -18,19 +18,20 @@ if is_tf_available():
logger = logging.getLogger(__name__)


def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)
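For intuition, a toy example of what this helper does; the tokenizer below is a made-up stand-in for a real subword tokenizer, not part of this commit:

# Illustrative only: a toy "tokenizer" that splits on whitespace and hyphens
# stands in for a real subword tokenizer such as BertTokenizer.
class ToyTokenizer:
    def tokenize(self, text):
        return text.replace("-", " - ").split()

doc_tokens = ["(", "1895", "-", "1943", ")", "."]
# The whitespace-level annotation covers the whole "(1895-1943)." span (indices 0..5);
# _improve_answer_span narrows it to the sub-token span that exactly matches "1895".
start, end = _improve_answer_span(doc_tokens, 0, 5, ToyTokenizer(), "1895")
# -> (1, 1)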
def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""
    best_score = None
...@@ -50,6 +51,7 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
    return cur_span_index == best_span_index


def _new_check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""
    # if len(doc_spans) == 1:
...@@ -71,14 +73,16 @@ def _new_check_is_max_context(doc_spans, cur_span_index, position):
    return cur_span_index == best_span_index
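The scoring logic is elided in this hunk; it follows the original BERT heuristic of preferring the span in which the token has the most surrounding context. A hedged sketch of that score, assuming the "start"/"length" keys carried by the span dicts built in squad_convert_examples_to_features:

# Hedged sketch of the elided scoring; mirrors the original BERT reference
# implementation, keys are assumptions based on the span dicts built below.
def _max_context_score_sketch(doc_spans, position):
    best_score, best_span_index = None, None
    for span_index, doc_span in enumerate(doc_spans):
        end = doc_span["start"] + doc_span["length"] - 1
        if position < doc_span["start"] or position > end:
            continue  # token not covered by this span
        num_left_context = position - doc_span["start"]
        num_right_context = end - position
        # A token has "max context" in the span where it is most centred,
        # with a small bonus for longer spans.
        score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
        if best_score is None or score > best_score:
            best_score, best_span_index = score, span_index
    return best_span_index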
def _is_whitespace(c):
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
        return True
    return False
def squad_convert_examples_to_features(
    examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False
):
    """
    Converts a list of examples into a list of features that can be directly given as input to a model.
    It is model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs.
...@@ -123,13 +127,12 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            end_position = example.end_position

            # If the answer cannot be found in the text, then skip this example.
            actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
            cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
            if actual_text.find(cleaned_answer_text) == -1:
                logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
                continue

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
...@@ -140,7 +143,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
...@@ -154,7 +156,9 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        spans = []

        truncated_query = tokenizer.encode(
            example.question_text, add_special_tokens=False, max_length=max_query_length
        )
        sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
        sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
...@@ -168,15 +172,18 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                return_overflowing_tokens=True,
                pad_to_max_length=True,
                stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
                truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
            )

            paragraph_len = min(
                len(all_doc_tokens) - len(spans) * doc_stride,
                max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
            )

            if tokenizer.pad_token_id in encoded_dict["input_ids"]:
                non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
            else:
                non_padded_ids = encoded_dict["input_ids"]

            tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
...@@ -202,16 +209,20 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
        for doc_span_index in range(len(spans)):
            for j in range(spans[doc_span_index]["paragraph_len"]):
                is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
                index = (
                    j
                    if tokenizer.padding_side == "left"
                    else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
                )
                spans[doc_span_index]["token_is_max_context"][index] = is_max_context

        for span in spans:
            # Identify the position of the CLS token
            cls_index = span["input_ids"].index(tokenizer.cls_token_id)
            # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens that can be in an answer)
            # The original TF implementation also keeps the classification token (set to 0) (not sure why...)
            p_mask = np.array(span["token_type_ids"])

            p_mask = np.minimum(p_mask, 1)
...@@ -224,7 +235,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
            # Set the CLS index to '0'
            p_mask[cls_index] = 0

            span_is_impossible = example.is_impossible
            start_position = 0
            end_position = 0
...@@ -251,51 +261,95 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            features.append(
                SquadFeatures(
                    span["input_ids"],
                    span["attention_mask"],
                    span["token_type_ids"],
                    cls_index,
                    p_mask.tolist(),
                    example_index=example_index,
                    unique_id=unique_id,
                    paragraph_len=span["paragraph_len"],
                    token_is_max_context=span["token_is_max_context"],
                    tokens=span["tokens"],
                    token_to_orig_map=span["token_to_orig_map"],
                    start_position=start_position,
                    end_position=end_position,
                )
            )

            unique_id += 1

    if return_dataset == "pt":
        if not is_torch_available():
            raise ImportError("Pytorch must be installed to return a pytorch dataset.")

        # Convert to Tensors and build dataset
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
        all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)

        if not is_training:
            all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
            dataset = TensorDataset(
                all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask
            )
        else:
            all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
            all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
            dataset = TensorDataset(
                all_input_ids,
                all_attention_masks,
                all_token_type_ids,
                all_start_positions,
                all_end_positions,
                all_cls_index,
                all_p_mask,
            )

        return features, dataset
    elif return_dataset == "tf":
        if not is_tf_available():
            raise ImportError("TensorFlow must be installed to return a TensorFlow dataset.")

        def gen():
            for ex in features:
                yield (
                    {
                        "input_ids": ex.input_ids,
                        "attention_mask": ex.attention_mask,
                        "token_type_ids": ex.token_type_ids,
                    },
                    {
                        "start_position": ex.start_position,
                        "end_position": ex.end_position,
                        "cls_index": ex.cls_index,
                        "p_mask": ex.p_mask,
                    },
                )

        return tf.data.Dataset.from_generator(
            gen,
            (
                {"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
                {"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32},
            ),
            (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "token_type_ids": tf.TensorShape([None]),
                },
                {
                    "start_position": tf.TensorShape([]),
                    "end_position": tf.TensorShape([]),
                    "cls_index": tf.TensorShape([]),
                    "p_mask": tf.TensorShape([None]),
                },
            ),
        )

    return features
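A hedged usage sketch of the converter; the checkpoint name, data directory, and length limits are illustrative placeholders, not values taken from this commit:

# Hedged usage sketch: placeholder values throughout.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
examples = SquadV1Processor().get_dev_examples("./squad_data")  # processor defined further down in this file

features, dataset = squad_convert_examples_to_features(
    examples,
    tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",  # "tf" would instead return a tf.data.Dataset of (inputs, targets) dicts
)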
...@@ -305,31 +359,32 @@ class SquadProcessor(DataProcessor):
    Processor for the SQuAD data set.
    Overridden by SquadV1Processor and SquadV2Processor, used by version 1.1 and version 2.0 of SQuAD, respectively.
    """

    train_file = None
    dev_file = None
    def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
        if not evaluate:
            answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
            answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
            answers = []
        else:
            answers = [
                {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
                for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
            ]
            answer = None
            answer_start = None

        return SquadExample(
            qas_id=tensor_dict["id"].numpy().decode("utf-8"),
            question_text=tensor_dict["question"].numpy().decode("utf-8"),
            context_text=tensor_dict["context"].numpy().decode("utf-8"),
            answer_text=answer,
            start_position_character=answer_start,
            title=tensor_dict["title"].numpy().decode("utf-8"),
            answers=answers,
        )
    def get_examples_from_dataset(self, dataset, evaluate=False):
...@@ -379,7 +434,9 @@ class SquadProcessor(DataProcessor):
        if self.train_file is None:
            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

        with open(
            os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
        ) as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "train")
...@@ -398,7 +455,9 @@ class SquadProcessor(DataProcessor):
        if self.dev_file is None:
            raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

        with open(
            os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
        ) as reader:
            input_data = json.load(reader)["data"]
        return self._create_examples(input_data, "dev")
...@@ -406,7 +465,7 @@ class SquadProcessor(DataProcessor):
        is_training = set_type == "train"
        examples = []
        for entry in tqdm(input_data):
            title = entry["title"]
            for paragraph in entry["paragraphs"]:
                context_text = paragraph["context"]
                for qa in paragraph["qas"]:
...@@ -424,8 +483,8 @@ class SquadProcessor(DataProcessor):
                    if not is_impossible:
                        if is_training:
                            answer = qa["answers"][0]
                            answer_text = answer["text"]
                            start_position_character = answer["answer_start"]
                        else:
                            answers = qa["answers"]
...@@ -437,12 +496,13 @@ class SquadProcessor(DataProcessor):
                        start_position_character=start_position_character,
                        title=title,
                        is_impossible=is_impossible,
                        answers=answers,
                    )

                    examples.append(example)
        return examples
class SquadV1Processor(SquadProcessor):
    train_file = "train-v1.1.json"
    dev_file = "dev-v1.1.json"
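A hedged sketch of obtaining examples through the processor; the data directory is a placeholder, and the tensorflow_datasets path is an assumption based on the fields read by _get_example_from_tensor_dict:

# Hedged sketch: "./squad_data" is a placeholder directory containing dev-v1.1.json.
processor = SquadV1Processor()
dev_examples = processor.get_dev_examples("./squad_data")

# Alternative, assuming a tensorflow_datasets SQuAD dict with the tensor fields
# read above ("id", "question", "context", "title", "answers"):
# import tensorflow_datasets as tfds
# dev_examples = processor.get_examples_from_dataset(tfds.load("squad"), evaluate=True)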
...@@ -468,7 +528,8 @@ class SquadExample(object):
        is_impossible: False by default, set to True if the example has no possible answer.
    """

    def __init__(
        self,
        qas_id,
        question_text,
        context_text,
...@@ -476,7 +537,8 @@ class SquadExample(object):
        start_position_character,
        title,
        answers=[],
        is_impossible=False,
    ):
        self.qas_id = qas_id
        self.question_text = question_text
        self.context_text = context_text
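A hedged sketch of constructing a SquadExample directly; every field value below is made up for illustration:

# Hedged sketch: all values are illustrative placeholders.
example = SquadExample(
    qas_id="qid-0001",
    question_text="When was the bridge completed?",
    context_text="The bridge was completed in 1937 after four years of construction.",
    answer_text="1937",
    start_position_character=28,  # character offset of "1937" in the context
    title="Example bridge",
    answers=[],
    is_impossible=False,
)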
...@@ -537,22 +599,21 @@ class SquadFeatures(object):
        end_position: index of the last token of the answer
    """
    def __init__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        cls_index,
        p_mask,
        example_index,
        unique_id,
        paragraph_len,
        token_is_max_context,
        tokens,
        token_to_orig_map,
        start_position,
        end_position,
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
...@@ -580,6 +641,7 @@ class SquadResult(object):
        start_logits: The logits corresponding to the start of the answer
        end_logits: The logits corresponding to the end of the answer
    """

    def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
        self.start_logits = start_logits
        self.end_logits = end_logits
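A hedged sketch of how SquadResult objects are typically filled during evaluation; model, batch, and features are assumed to come from the caller's loop over a dataset built with return_dataset="pt":

# Hedged evaluation sketch: `model`, `batch`, and `features` are assumed to come
# from the caller; only the SquadResult construction mirrors this file.
all_results = []
model.eval()
with torch.no_grad():
    outputs = model(input_ids=batch[0], attention_mask=batch[1], token_type_ids=batch[2])

for i, feature_index in enumerate(batch[3]):  # batch[3]: feature indices from the eval TensorDataset above
    feature = features[feature_index.item()]
    all_results.append(
        SquadResult(feature.unique_id, outputs[0][i].tolist(), outputs[1][i].tolist())
    )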