Commit 9b312f9d authored by erenup
Browse files

initial version for roberta squad

parent 40ed7172
...@@ -39,6 +39,7 @@ from tqdm import tqdm, trange ...@@ -39,6 +39,7 @@ from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, BertConfig, from transformers import (WEIGHTS_NAME, BertConfig,
BertForQuestionAnswering, BertTokenizer, BertForQuestionAnswering, BertTokenizer,
RobertaForQuestionAnswering, RobertaTokenizer, RobertaConfig,
XLMConfig, XLMForQuestionAnswering, XLMConfig, XLMForQuestionAnswering,
XLMTokenizer, XLNetConfig, XLMTokenizer, XLNetConfig,
XLNetForQuestionAnswering, XLNetForQuestionAnswering,
...@@ -53,10 +54,11 @@ from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_e ...@@ -53,10 +54,11 @@ from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_e
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \ ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
for conf in (BertConfig, XLNetConfig, XLMConfig)), ()) for conf in (BertConfig, RobertaConfig, XLNetConfig, XLMConfig)), ())
MODEL_CLASSES = { MODEL_CLASSES = {
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer), 'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
'roberta': (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), 'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), 'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer), 'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
...@@ -141,13 +143,11 @@ def train(args, train_dataset, model, tokenizer): ...@@ -141,13 +143,11 @@ def train(args, train_dataset, model, tokenizer):
inputs = { inputs = {
'input_ids': batch[0], 'input_ids': batch[0],
'attention_mask': batch[1], 'attention_mask': batch[1],
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
'start_positions': batch[3], 'start_positions': batch[3],
'end_positions': batch[4] 'end_positions': batch[4],
} }
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
if args.model_type in ['xlnet', 'xlm']: if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[5], 'p_mask': batch[6]}) inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})
...@@ -241,12 +241,9 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -241,12 +241,9 @@ def evaluate(args, model, tokenizer, prefix=""):
with torch.no_grad(): with torch.no_grad():
inputs = { inputs = {
'input_ids': batch[0], 'input_ids': batch[0],
'attention_mask': batch[1] 'attention_mask': batch[1],
'token_type_ids': None if args.model_type in ['xlm', 'roberta', 'distilbert'] else batch[2],
} }
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
example_indices = batch[3] example_indices = batch[3]
# XLNet and XLM use more arguments for their predictions # XLNet and XLM use more arguments for their predictions
...@@ -311,7 +308,7 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -311,7 +308,7 @@ def evaluate(args, model, tokenizer, prefix=""):
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size, predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
args.max_answer_length, args.do_lower_case, output_prediction_file, args.max_answer_length, args.do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.verbose_logging, output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold) args.version_2_with_negative, args.null_score_diff_threshold, tokenizer)
# Compute the F1 and exact scores. # Compute the F1 and exact scores.
results = squad_evaluate(examples, predictions) results = squad_evaluate(examples, predictions)
......
...@@ -99,7 +99,7 @@ if is_torch_available(): ...@@ -99,7 +99,7 @@ if is_torch_available():
XLM_PRETRAINED_MODEL_ARCHIVE_MAP) XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, from .modeling_roberta import (RobertaForMaskedLM, RobertaModel,
RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForSequenceClassification, RobertaForMultipleChoice,
RobertaForTokenClassification, RobertaForTokenClassification, RobertaForQuestionAnswering,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_distilbert import (DistilBertPreTrainedModel, DistilBertForMaskedLM, DistilBertModel, from .modeling_distilbert import (DistilBertPreTrainedModel, DistilBertForMaskedLM, DistilBertModel,
DistilBertForSequenceClassification, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
......
...@@ -377,7 +377,8 @@ def compute_predictions_logits( ...@@ -377,7 +377,8 @@ def compute_predictions_logits(
output_null_log_odds_file, output_null_log_odds_file,
verbose_logging, verbose_logging,
version_2_with_negative, version_2_with_negative,
null_score_diff_threshold null_score_diff_threshold,
tokenizer,
): ):
"""Write final predictions to the json file and log-odds of null if needed.""" """Write final predictions to the json file and log-odds of null if needed."""
logger.info("Writing predictions to: %s" % (output_prediction_file)) logger.info("Writing predictions to: %s" % (output_prediction_file))
...@@ -474,11 +475,14 @@ def compute_predictions_logits( ...@@ -474,11 +475,14 @@ def compute_predictions_logits(
orig_doc_start = feature.token_to_orig_map[pred.start_index] orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index] orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
# De-tokenize WordPieces that have been split off. tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "") # tok_text = " ".join(tok_tokens)
#
# # De-tokenize WordPieces that have been split off.
# tok_text = tok_text.replace(" ##", "")
# tok_text = tok_text.replace("##", "")
# Clean whitespace # Clean whitespace
tok_text = tok_text.strip() tok_text = tok_text.strip()
......
...@@ -140,7 +140,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -140,7 +140,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
tok_to_orig_index.append(i) tok_to_orig_index.append(i)
all_doc_tokens.append(sub_token) all_doc_tokens.append(sub_token)
if is_training and not example.is_impossible: if is_training and not example.is_impossible:
tok_start_position = orig_to_tok_index[example.start_position] tok_start_position = orig_to_tok_index[example.start_position]
if example.end_position < len(example.doc_tokens) - 1: if example.end_position < len(example.doc_tokens) - 1:
...@@ -155,7 +154,8 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -155,7 +154,8 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
spans = [] spans = []
truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length) truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \
if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence
sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
span_doc_tokens = all_doc_tokens span_doc_tokens = all_doc_tokens
......
...@@ -555,3 +555,89 @@ class RobertaClassificationHead(nn.Module): ...@@ -555,3 +555,89 @@ class RobertaClassificationHead(nn.Module):
x = self.dropout(x) x = self.dropout(x)
x = self.out_proj(x) x = self.out_proj(x)
return x return x
@add_start_docstrings("""Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForQuestionAnswering(BertPreTrainedModel):
    r"""
    **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
        Labels for position (index) of the start of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (`sequence_length`).
        Position outside of the sequence are not taken into account for computing the loss.
    **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
        Labels for position (index) of the end of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (`sequence_length`).
        Position outside of the sequence are not taken into account for computing the loss.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        model = RobertaForQuestionAnswering.from_pretrained('roberta-base')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        start_positions = torch.tensor([1])
        end_positions = torch.tensor([3])
        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        loss, start_scores, end_scores = outputs[:3]

    """
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):
        super(RobertaForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config)
        # One linear head over every token's final hidden state producing
        # config.num_labels (=2 for SQuAD) logits: (start, end).
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                start_positions=None, end_positions=None):
        """Run the encoder and the span-classification head.

        Returns ``(start_logits, end_logits, *extras)``; when both
        ``start_positions`` and ``end_positions`` are given, the averaged
        start/end cross-entropy loss is prepended to the tuple.
        """
        outputs = self.roberta(input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids,
                               position_ids=position_ids,
                               head_mask=head_mask)

        # outputs[0] is the last-layer hidden states, (batch, seq_len, hidden).
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        # Split the 2 logits per token into separate start/end score tensors
        # of shape (batch, seq_len).
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds an extra dimension — drop it.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs;
            # clamp them to `ignored_index` so CrossEntropyLoss skips them.
            # Use out-of-place clamp so the caller's label tensors are not
            # mutated as a side effect.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment