Commit 24ed0b93 authored by thomwolf's avatar thomwolf
Browse files

updating run_xlnet_classifier

parent f6081f22
...@@ -124,3 +124,6 @@ tensorflow_code ...@@ -124,3 +124,6 @@ tensorflow_code
# Models # Models
models models
proc_data proc_data
# examples
examples/runs
\ No newline at end of file
...@@ -54,91 +54,58 @@ def main(): ...@@ -54,91 +54,58 @@ def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters ## Required parameters
parser.add_argument("--data_dir", parser.add_argument("--data_dir", default=None, type=str, required=True,
default=None,
type=str,
required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.") help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--xlnet_model", default="xlnet-large-cased", type=str, parser.add_argument("--task_name", default=None, type=str, required=True,
help="XLNet pre-trained model: currently only xlnet-large-cased.")
parser.add_argument("--task_name",
default=None,
type=str,
required=True,
help="The name of the task to train.") help="The name of the task to train.")
parser.add_argument("--output_dir", parser.add_argument("--output_dir", default=None, type=str, required=True,
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.") help="The output directory where the model predictions and checkpoints will be written.")
# training
## Other parameters parser.add_argument("--do_train", action='store_true',
parser.add_argument("--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument("--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument("--do_train",
action='store_true',
help="Whether to run training.") help="Whether to run training.")
parser.add_argument("--do_eval", parser.add_argument("--learning_rate", default=5e-5, type=float,
action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--do_lower_case",
action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--train_batch_size",
default=32,
type=int,
help="Total batch size for training.")
parser.add_argument("--eval_batch_size",
default=8,
type=int,
help="Total batch size for eval.")
parser.add_argument("--learning_rate",
default=5e-5,
type=float,
help="The initial learning rate for Adam.") help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs", parser.add_argument("--num_train_epochs", default=3.0, type=float,
default=3.0,
type=float,
help="Total number of training epochs to perform.") help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion", parser.add_argument("--warmup_proportion", default=0.1, type=float,
default=0.1,
type=float,
help="Proportion of training to perform linear learning rate warmup for. " help="Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10%% of training.") "E.g., 0.1 = 10%% of training.")
parser.add_argument("--no_cuda", parser.add_argument("--train_batch_size", default=32, type=int,
action='store_true', help="Total batch size for training.")
help="Whether not to use CUDA when available") parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
parser.add_argument('--overwrite_output_dir',
action='store_true',
help="Overwrite the content of the output directory")
parser.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.") help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16', parser.add_argument('--fp16', action='store_true',
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit") help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale', parser.add_argument('--loss_scale', type=float, default=0,
type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n" "0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n") "Positive power of 2: static loss scaling value.\n")
# evaluation
parser.add_argument("--do_eval", action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--eval_batch_size", default=8, type=int,
help="Total batch size for eval.")
# Model
parser.add_argument("--xlnet_model", default="xlnet-large-cased", type=str,
help="XLNet pre-trained model: currently only xlnet-large-cased.")
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--cache_dir", default="", type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
# task specific
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument('--overwrite_output_dir', action='store_true',
help="Overwrite the content of the output directory")
# Misc
parser.add_argument("--no_cuda", action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--local_rank", type=int, default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args() args = parser.parse_args()
...@@ -306,7 +273,7 @@ def main(): ...@@ -306,7 +273,7 @@ def main():
input_ids, input_mask, segment_ids, label_ids = batch input_ids, input_mask, segment_ids, label_ids = batch
# define a new function to compute loss values for both output_modes # define a new function to compute loss values for both output_modes
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
if output_mode == "classification": if output_mode == "classification":
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
...@@ -420,7 +387,7 @@ def main(): ...@@ -420,7 +387,7 @@ def main():
label_ids = label_ids.to(device) label_ids = label_ids.to(device)
with torch.no_grad(): with torch.no_grad():
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
# create eval loss and other metric required by the task # create eval loss and other metric required by the task
if output_mode == "classification": if output_mode == "classification":
...@@ -501,7 +468,7 @@ def main(): ...@@ -501,7 +468,7 @@ def main():
label_ids = label_ids.to(device) label_ids = label_ids.to(device)
with torch.no_grad(): with torch.no_grad():
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None) logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1)) tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
......
This diff is collapsed.
...@@ -606,7 +606,7 @@ class BertPreTrainedModel(nn.Module): ...@@ -606,7 +606,7 @@ class BertPreTrainedModel(nn.Module):
)) ))
self.config = config self.config = config
def init_bert_weights(self, module): def init_weights(self, module):
""" Initialize the weights. """ Initialize the weights.
""" """
if isinstance(module, (nn.Linear, nn.Embedding)): if isinstance(module, (nn.Linear, nn.Embedding)):
...@@ -823,7 +823,7 @@ class BertModel(BertPreTrainedModel): ...@@ -823,7 +823,7 @@ class BertModel(BertPreTrainedModel):
self.encoder = BertEncoder(config, output_attentions=output_attentions, self.encoder = BertEncoder(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output) keep_multihead_output=keep_multihead_output)
self.pooler = BertPooler(config) self.pooler = BertPooler(config)
self.apply(self.init_bert_weights) self.apply(self.init_weights)
def prune_heads(self, heads_to_prune): def prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """ Prunes heads of the model.
...@@ -951,7 +951,7 @@ class BertForPreTraining(BertPreTrainedModel): ...@@ -951,7 +951,7 @@ class BertForPreTraining(BertPreTrainedModel):
self.bert = BertModel(config, output_attentions=output_attentions, self.bert = BertModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output) keep_multihead_output=keep_multihead_output)
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights) self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None): def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, outputs = self.bert(input_ids, token_type_ids, attention_mask,
...@@ -1030,7 +1030,7 @@ class BertForMaskedLM(BertPreTrainedModel): ...@@ -1030,7 +1030,7 @@ class BertForMaskedLM(BertPreTrainedModel):
self.bert = BertModel(config, output_attentions=output_attentions, self.bert = BertModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output) keep_multihead_output=keep_multihead_output)
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights) self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None): def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, outputs = self.bert(input_ids, token_type_ids, attention_mask,
...@@ -1105,7 +1105,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel): ...@@ -1105,7 +1105,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.bert = BertModel(config, output_attentions=output_attentions, self.bert = BertModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output) keep_multihead_output=keep_multihead_output)
self.cls = BertOnlyNSPHead(config) self.cls = BertOnlyNSPHead(config)
self.apply(self.init_bert_weights) self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None): def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, outputs = self.bert(input_ids, token_type_ids, attention_mask,
...@@ -1184,7 +1184,7 @@ class BertForSequenceClassification(BertPreTrainedModel): ...@@ -1184,7 +1184,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
keep_multihead_output=keep_multihead_output) keep_multihead_output=keep_multihead_output)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels) self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights) self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask) outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
...@@ -1261,7 +1261,7 @@ class BertForMultipleChoice(BertPreTrainedModel): ...@@ -1261,7 +1261,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
keep_multihead_output=keep_multihead_output) keep_multihead_output=keep_multihead_output)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1) self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_bert_weights) self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_input_ids = input_ids.view(-1, input_ids.size(-1))
...@@ -1343,7 +1343,7 @@ class BertForTokenClassification(BertPreTrainedModel): ...@@ -1343,7 +1343,7 @@ class BertForTokenClassification(BertPreTrainedModel):
keep_multihead_output=keep_multihead_output) keep_multihead_output=keep_multihead_output)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels) self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights) self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None): def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask) outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
...@@ -1428,7 +1428,7 @@ class BertForQuestionAnswering(BertPreTrainedModel): ...@@ -1428,7 +1428,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.bert = BertModel(config, output_attentions=output_attentions, self.bert = BertModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output) keep_multihead_output=keep_multihead_output)
self.qa_outputs = nn.Linear(config.hidden_size, 2) self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.apply(self.init_bert_weights) self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
end_positions=None, head_mask=None): end_positions=None, head_mask=None):
......
...@@ -633,7 +633,7 @@ class XLNetPreTrainedModel(nn.Module): ...@@ -633,7 +633,7 @@ class XLNetPreTrainedModel(nn.Module):
)) ))
self.config = config self.config = config
def init_xlnet_weights(self, module): def init_weights(self, module):
""" Initialize the weights. """ Initialize the weights.
""" """
if isinstance(module, (nn.Linear, nn.Embedding)): if isinstance(module, (nn.Linear, nn.Embedding)):
...@@ -904,14 +904,14 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -904,14 +904,14 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb = pos_emb.to(next(self.parameters())) pos_emb = pos_emb.to(next(self.parameters()))
return pos_emb return pos_emb
def forward(self, inp_k, seg_id=None, input_mask=None, def forward(self, inp_k, token_type_ids=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None, mems=None, perm_mask=None, target_mapping=None, inp_q=None,
output_all_encoded_layers=True, head_mask=None): output_all_encoded_layers=True, head_mask=None):
""" """
Args: Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs. inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs. token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask. attention_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding. 0 for real tokens and 1 for padding.
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer. from previous batches. The length of the list equals n_layer.
...@@ -945,8 +945,8 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -945,8 +945,8 @@ class XLNetModel(XLNetPreTrainedModel):
# but we want a unified interface in the library with the batch size on the first dimension # but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end # so we move here the first dimension (batch) to the end
inp_k = inp_k.transpose(0, 1).contiguous() inp_k = inp_k.transpose(0, 1).contiguous()
seg_id = seg_id.transpose(0, 1).contiguous() if seg_id is not None else None token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None
...@@ -969,11 +969,11 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -969,11 +969,11 @@ class XLNetModel(XLNetPreTrainedModel):
raise ValueError('Unsupported attention type: {}'.format(self.attn_type)) raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
# data mask: input mask & perm mask # data mask: input mask & perm mask
if input_mask is not None and perm_mask is not None: if attention_mask is not None and perm_mask is not None:
data_mask = input_mask[None] + perm_mask data_mask = attention_mask[None] + perm_mask
elif input_mask is not None and perm_mask is None: elif attention_mask is not None and perm_mask is None:
data_mask = input_mask[None] data_mask = attention_mask[None]
elif input_mask is None and perm_mask is not None: elif attention_mask is None and perm_mask is not None:
data_mask = perm_mask data_mask = perm_mask
else: else:
data_mask = None data_mask = None
...@@ -1011,13 +1011,13 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -1011,13 +1011,13 @@ class XLNetModel(XLNetPreTrainedModel):
output_g = None output_g = None
##### Segment embedding ##### Segment embedding
if seg_id is not None: if token_type_ids is not None:
# Convert `seg_id` to one-hot `seg_mat` # Convert `token_type_ids` to one-hot `seg_mat`
mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device) mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
cat_ids = torch.cat([mem_pad, seg_id], dim=0) cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
# `1` indicates not in the same segment [qlen x klen x bsz] # `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = (seg_id[:, None] != cat_ids[None, :]).long() seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float) seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float)
else: else:
seg_mat = None seg_mat = None
...@@ -1076,8 +1076,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -1076,8 +1076,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
Inputs: Inputs:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs. inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs. token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask. attention_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding. 0 for real tokens and 1 for padding.
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer. from previous batches. The length of the list equals n_layer.
...@@ -1112,14 +1112,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -1112,14 +1112,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
```python ```python
# Already been converted into WordPiece token ids # Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768, config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
n_layer=12, num_attention_heads=12, intermediate_size=3072) n_layer=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.XLNetModel(config=config) model = modeling.XLNetModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) all_encoder_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
``` ```
""" """
def __init__(self, config, output_attentions=False, keep_multihead_output=False): def __init__(self, config, output_attentions=False, keep_multihead_output=False):
...@@ -1134,7 +1134,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -1134,7 +1134,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
# Tie weights # Tie weights
self.apply(self.init_xlnet_weights) self.apply(self.init_weights)
self.tie_weights() self.tie_weights()
def tie_weights(self): def tie_weights(self):
...@@ -1142,14 +1142,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -1142,14 +1142,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
""" """
self.lm_loss.weight = self.transformer.word_embedding.weight self.lm_loss.weight = self.transformer.word_embedding.weight
def forward(self, inp_k, seg_id=None, input_mask=None, def forward(self, inp_k, token_type_ids=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None, mems=None, perm_mask=None, target_mapping=None, inp_q=None,
target=None, output_all_encoded_layers=True, head_mask=None): target=None, output_all_encoded_layers=True, head_mask=None):
""" """
Args: Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs. inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs. token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask. attention_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding. 0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer. from previous batches. The length of the list equals n_layer.
...@@ -1171,7 +1171,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -1171,7 +1171,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation. to pool the input to get a vector representation.
""" """
output, hidden_states, new_mems = self.transformer(inp_k, seg_id, input_mask, output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, attention_mask,
mems, perm_mask, target_mapping, inp_q, mems, perm_mask, target_mapping, inp_q,
output_all_encoded_layers, head_mask) output_all_encoded_layers, head_mask)
...@@ -1200,7 +1200,7 @@ class XLNetSequenceSummary(nn.Module): ...@@ -1200,7 +1200,7 @@ class XLNetSequenceSummary(nn.Module):
super(XLNetSequenceSummary, self).__init__() super(XLNetSequenceSummary, self).__init__()
self.summary_type = summary_type self.summary_type = summary_type
if use_proj: if use_proj:
self.summary = nn.Linear(config.hidden_size, num_labels) self.summary = nn.Linear(config.d_model, config.d_model)
else: else:
self.summary = None self.summary = None
if summary_type == 'attn': if summary_type == 'attn':
...@@ -1211,19 +1211,20 @@ class XLNetSequenceSummary(nn.Module): ...@@ -1211,19 +1211,20 @@ class XLNetSequenceSummary(nn.Module):
self.dropout = nn.Dropout(config.dropout) self.dropout = nn.Dropout(config.dropout)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states, input_mask=None): def forward(self, hidden_states):
""" hidden_states: float Tensor in shape [bsz, seq_len, d_model], the hidden-states of the last layer."""
if self.summary_type == 'last': if self.summary_type == 'last':
output = hidden_states[-1] output = hidden_states[:, -1]
elif self.summary_type == 'first': elif self.summary_type == 'first':
output = hidden_states[0] output = hidden_states[:, 0]
elif self.summary_type == 'mean': elif self.summary_type == 'mean':
output = hidden_states.mean(dim=0) output = hidden_states.mean(dim=1)
elif summary_type == 'attn': elif summary_type == 'attn':
raise NotImplementedError raise NotImplementedError
output = self.summary(output) output = self.summary(output)
output = self.dropout(output)
output = self.activation(output) output = self.activation(output)
output = self.dropout(output)
return output return output
...@@ -1240,8 +1241,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -1240,8 +1241,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Inputs: Inputs:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs. inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs. token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask. attention_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding. 0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer. from previous batches. The length of the list equals n_layer.
...@@ -1277,14 +1278,14 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -1277,14 +1278,14 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
```python ```python
# Already been converted into WordPiece token ids # Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768, config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
n_layer=12, num_attention_heads=12, intermediate_size=3072) n_layer=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.XLNetModel(config=config) model = modeling.XLNetModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) all_encoder_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
``` ```
""" """
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2, def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
...@@ -1302,17 +1303,17 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -1302,17 +1303,17 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type, self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type,
use_proj=use_proj, output_attentions=output_attentions, use_proj=use_proj, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output) keep_multihead_output=keep_multihead_output)
self.loss_proj = nn.Linear(config.d_model, num_classes if not is_regression else 1) self.loss_proj = nn.Linear(config.d_model, num_labels if not is_regression else 1)
self.apply(self.init_bert_weights) self.apply(self.init_weights)
def forward(self, inp_k, seg_id=None, input_mask=None, def forward(self, inp_k, token_type_ids=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None, mems=None, perm_mask=None, target_mapping=None, inp_q=None,
target=None, output_all_encoded_layers=True, head_mask=None): target=None, output_all_encoded_layers=True, head_mask=None):
""" """
Args: Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs. inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs. token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask. attention_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding. 0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer. from previous batches. The length of the list equals n_layer.
...@@ -1331,7 +1332,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -1331,7 +1332,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Only used during pretraining for two-stream attention. Only used during pretraining for two-stream attention.
Set to None during finetuning. Set to None during finetuning.
""" """
output, _, new_mems = self.transformer(inp_k, seg_id, input_mask, output, _, new_mems = self.transformer(inp_k, token_type_ids, attention_mask,
mems, perm_mask, target_mapping, inp_q, mems, perm_mask, target_mapping, inp_q,
output_all_encoded_layers, head_mask) output_all_encoded_layers, head_mask)
...@@ -1356,3 +1357,96 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -1356,3 +1357,96 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
# if self.output_attentions: # if self.output_attentions:
return logits, new_mems return logits, new_mems
# return all_attentions, encoded_layers, pooled_output # return all_attentions, encoded_layers, pooled_output
class XLNetForQuestionAnswering(XLNetPreTrainedModel):
"""XLNet model for Question Answering (span extraction).
This module is composed of the XLNet model with a linear layer on top of
the sequence output that computes start_logits and end_logits
Params:
`config`: a XLNetConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see XLNet paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs:
if `start_positions` and `end_positions` are not `None`:
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
position tokens of shape [batch_size, sequence_length].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = XLNetConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = XLNetForQuestionAnswering(config)
start_logits, end_logits = model(input_ids, token_type_ids, attention_mask)
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(XLNetForQuestionAnswering, self).__init__(config)
self.output_attentions = output_attentions
self.transformer = XLNetModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.apply(self.init_weights)
def forward(self, inp_k, token_type_ids=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
start_positions=None, end_positions=None,
output_all_encoded_layers=True, head_mask=None):
output, _, new_mems = self.transformer(inp_k, token_type_ids, attention_mask,
mems, perm_mask, target_mapping, inp_q,
output_all_encoded_layers, head_mask)
logits = self.qa_outputs(output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
return total_loss
elif self.output_attentions:
return all_attentions, start_logits, end_logits
return start_logits, end_logits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment