Commit 24ed0b93 authored by thomwolf's avatar thomwolf
Browse files

updating run_xlnet_classifier

parent f6081f22
......@@ -124,3 +124,6 @@ tensorflow_code
# Models
models
proc_data
# examples
examples/runs
\ No newline at end of file
......@@ -54,91 +54,58 @@ def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--data_dir",
default=None,
type=str,
required=True,
parser.add_argument("--data_dir", default=None, type=str, required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--xlnet_model", default="xlnet-large-cased", type=str,
help="XLNet pre-trained model: currently only xlnet-large-cased.")
parser.add_argument("--task_name",
default=None,
type=str,
required=True,
parser.add_argument("--task_name", default=None, type=str, required=True,
help="The name of the task to train.")
parser.add_argument("--output_dir",
default=None,
type=str,
required=True,
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
## Other parameters
parser.add_argument("--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument("--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument("--do_train",
action='store_true',
# training
parser.add_argument("--do_train", action='store_true',
help="Whether to run training.")
parser.add_argument("--do_eval",
action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--do_lower_case",
action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--train_batch_size",
default=32,
type=int,
help="Total batch size for training.")
parser.add_argument("--eval_batch_size",
default=8,
type=int,
help="Total batch size for eval.")
parser.add_argument("--learning_rate",
default=5e-5,
type=float,
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs",
default=3.0,
type=float,
parser.add_argument("--num_train_epochs", default=3.0, type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion",
default=0.1,
type=float,
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10%% of training.")
parser.add_argument("--no_cuda",
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument('--overwrite_output_dir',
action='store_true',
help="Overwrite the content of the output directory")
parser.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
parser.add_argument("--train_batch_size", default=32, type=int,
help="Total batch size for training.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16',
action='store_true',
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
type=float, default=0,
parser.add_argument('--loss_scale', type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
# evaluation
parser.add_argument("--do_eval", action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--eval_batch_size", default=8, type=int,
help="Total batch size for eval.")
# Model
parser.add_argument("--xlnet_model", default="xlnet-large-cased", type=str,
help="XLNet pre-trained model: currently only xlnet-large-cased.")
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--cache_dir", default="", type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
# task specific
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument('--overwrite_output_dir', action='store_true',
help="Overwrite the content of the output directory")
# Misc
parser.add_argument("--no_cuda", action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--local_rank", type=int, default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args()
......@@ -306,7 +273,7 @@ def main():
input_ids, input_mask, segment_ids, label_ids = batch
# define a new function to compute loss values for both output_modes
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
if output_mode == "classification":
loss_fct = CrossEntropyLoss()
......@@ -420,7 +387,7 @@ def main():
label_ids = label_ids.to(device)
with torch.no_grad():
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
# create eval loss and other metric required by the task
if output_mode == "classification":
......@@ -501,7 +468,7 @@ def main():
label_ids = label_ids.to(device)
with torch.no_grad():
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
......
This diff is collapsed.
......@@ -606,7 +606,7 @@ class BertPreTrainedModel(nn.Module):
))
self.config = config
def init_bert_weights(self, module):
def init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
......@@ -823,7 +823,7 @@ class BertModel(BertPreTrainedModel):
self.encoder = BertEncoder(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
self.apply(self.init_weights)
def prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
......@@ -951,7 +951,7 @@ class BertForPreTraining(BertPreTrainedModel):
self.bert = BertModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask,
......@@ -1030,7 +1030,7 @@ class BertForMaskedLM(BertPreTrainedModel):
self.bert = BertModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask,
......@@ -1105,7 +1105,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.bert = BertModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.cls = BertOnlyNSPHead(config)
self.apply(self.init_bert_weights)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask,
......@@ -1184,7 +1184,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
keep_multihead_output=keep_multihead_output)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
......@@ -1261,7 +1261,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
keep_multihead_output=keep_multihead_output)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_bert_weights)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
......@@ -1343,7 +1343,7 @@ class BertForTokenClassification(BertPreTrainedModel):
keep_multihead_output=keep_multihead_output)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
......@@ -1428,7 +1428,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.bert = BertModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.apply(self.init_bert_weights)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
end_positions=None, head_mask=None):
......
......@@ -633,7 +633,7 @@ class XLNetPreTrainedModel(nn.Module):
))
self.config = config
def init_xlnet_weights(self, module):
def init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
......@@ -904,14 +904,14 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb = pos_emb.to(next(self.parameters()))
return pos_emb
def forward(self, inp_k, seg_id=None, input_mask=None,
def forward(self, inp_k, token_type_ids=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
output_all_encoded_layers=True, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
attention_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
......@@ -945,8 +945,8 @@ class XLNetModel(XLNetPreTrainedModel):
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
inp_k = inp_k.transpose(0, 1).contiguous()
seg_id = seg_id.transpose(0, 1).contiguous() if seg_id is not None else None
input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None
......@@ -969,11 +969,11 @@ class XLNetModel(XLNetPreTrainedModel):
raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
# data mask: input mask & perm mask
if input_mask is not None and perm_mask is not None:
data_mask = input_mask[None] + perm_mask
elif input_mask is not None and perm_mask is None:
data_mask = input_mask[None]
elif input_mask is None and perm_mask is not None:
if attention_mask is not None and perm_mask is not None:
data_mask = attention_mask[None] + perm_mask
elif attention_mask is not None and perm_mask is None:
data_mask = attention_mask[None]
elif attention_mask is None and perm_mask is not None:
data_mask = perm_mask
else:
data_mask = None
......@@ -1011,13 +1011,13 @@ class XLNetModel(XLNetPreTrainedModel):
output_g = None
##### Segment embedding
if seg_id is not None:
# Convert `seg_id` to one-hot `seg_mat`
if token_type_ids is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
cat_ids = torch.cat([mem_pad, seg_id], dim=0)
cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = (seg_id[:, None] != cat_ids[None, :]).long()
seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float)
else:
seg_mat = None
......@@ -1076,8 +1076,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
Inputs:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
attention_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
......@@ -1112,14 +1112,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
n_layer=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.XLNetModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
......@@ -1134,7 +1134,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
# Tie weights
self.apply(self.init_xlnet_weights)
self.apply(self.init_weights)
self.tie_weights()
def tie_weights(self):
......@@ -1142,14 +1142,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
"""
self.lm_loss.weight = self.transformer.word_embedding.weight
def forward(self, inp_k, seg_id=None, input_mask=None,
def forward(self, inp_k, token_type_ids=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
target=None, output_all_encoded_layers=True, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
attention_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
......@@ -1171,7 +1171,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
output, hidden_states, new_mems = self.transformer(inp_k, seg_id, input_mask,
output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, attention_mask,
mems, perm_mask, target_mapping, inp_q,
output_all_encoded_layers, head_mask)
......@@ -1200,7 +1200,7 @@ class XLNetSequenceSummary(nn.Module):
super(XLNetSequenceSummary, self).__init__()
self.summary_type = summary_type
if use_proj:
self.summary = nn.Linear(config.hidden_size, num_labels)
self.summary = nn.Linear(config.d_model, config.d_model)
else:
self.summary = None
if summary_type == 'attn':
......@@ -1211,19 +1211,20 @@ class XLNetSequenceSummary(nn.Module):
self.dropout = nn.Dropout(config.dropout)
self.activation = nn.Tanh()
def forward(self, hidden_states, input_mask=None):
def forward(self, hidden_states):
""" hidden_states: float Tensor in shape [bsz, seq_len, d_model], the hidden-states of the last layer."""
if self.summary_type == 'last':
output = hidden_states[-1]
output = hidden_states[:, -1]
elif self.summary_type == 'first':
output = hidden_states[0]
output = hidden_states[:, 0]
elif self.summary_type == 'mean':
output = hidden_states.mean(dim=0)
output = hidden_states.mean(dim=1)
elif summary_type == 'attn':
raise NotImplementedError
output = self.summary(output)
output = self.dropout(output)
output = self.activation(output)
output = self.dropout(output)
return output
......@@ -1240,8 +1241,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Inputs:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
attention_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
......@@ -1277,14 +1278,14 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
n_layer=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.XLNetModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
```
"""
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
......@@ -1302,17 +1303,17 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type,
use_proj=use_proj, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.loss_proj = nn.Linear(config.d_model, num_classes if not is_regression else 1)
self.apply(self.init_bert_weights)
self.loss_proj = nn.Linear(config.d_model, num_labels if not is_regression else 1)
self.apply(self.init_weights)
def forward(self, inp_k, seg_id=None, input_mask=None,
def forward(self, inp_k, token_type_ids=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
target=None, output_all_encoded_layers=True, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
attention_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
......@@ -1331,7 +1332,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Only used during pretraining for two-stream attention.
Set to None during finetuning.
"""
output, _, new_mems = self.transformer(inp_k, seg_id, input_mask,
output, _, new_mems = self.transformer(inp_k, token_type_ids, attention_mask,
mems, perm_mask, target_mapping, inp_q,
output_all_encoded_layers, head_mask)
......@@ -1356,3 +1357,96 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
# if self.output_attentions:
return logits, new_mems
# return all_attentions, encoded_layers, pooled_output
class XLNetForQuestionAnswering(XLNetPreTrainedModel):
"""XLNet model for Question Answering (span extraction).
This module is composed of the XLNet model with a linear layer on top of
the sequence output that computes start_logits and end_logits
Params:
`config`: a XLNetConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see XLNet paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `start_positions` and `end_positions` are not `None`:
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
position tokens of shape [batch_size, sequence_length].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = XLNetConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = XLNetForQuestionAnswering(config)
start_logits, end_logits = model(input_ids, token_type_ids, attention_mask)
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(XLNetForQuestionAnswering, self).__init__(config)
self.output_attentions = output_attentions
self.transformer = XLNetModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.apply(self.init_weights)
def forward(self, inp_k, token_type_ids=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
start_positions=None, end_positions=None,
output_all_encoded_layers=True, head_mask=None):
output, _, new_mems = self.transformer(inp_k, token_type_ids, attention_mask,
mems, perm_mask, target_mapping, inp_q,
output_all_encoded_layers, head_mask)
logits = self.qa_outputs(output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
return total_loss
elif self.output_attentions:
return all_attentions, start_logits, end_logits
return start_logits, end_logits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment