Commit f6081f22 authored by thomwolf

add xlnetforsequence classif and run_classifier example for xlnet

parent c946bb51
@@ -14,7 +14,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from pytorch_pretrained_bert import BertForSequenceClassification, BertTokenizer
-from run_classifier_dataset_utils import processors, output_modes, convert_examples_to_features, compute_metrics
+from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
 logger = logging.getLogger(__name__)
...
@@ -39,7 +39,7 @@ from pytorch_pretrained_bert.modeling import BertForSequenceClassification
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
-from run_classifier_dataset_utils import processors, output_modes, convert_examples_to_features, compute_metrics
+from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
 if sys.version_info[0] == 2:
     import cPickle as pickle
...
@@ -38,7 +38,7 @@ from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
 from pytorch_pretrained_bert.tokenization import BertTokenizer
-from run_squad_dataset_utils import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
+from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
 if sys.version_info[0] == 2:
     import cPickle as pickle
...
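The example scripts now pull their shared helpers from utils_glue and utils_squad instead of the old run_classifier_dataset_utils / run_squad_dataset_utils modules. A minimal, hypothetical sketch of the new import path; the task key and the processor/output-mode conventions below are assumptions based on the surrounding scripts, not something this diff shows:

import torch
from utils_glue import processors, output_modes

task_name = "mrpc"                          # assumed GLUE task key
processor = processors[task_name]()         # processor registry imported above
output_mode = output_modes[task_name]       # assumed to be "classification" or "regression"
# train_examples = processor.get_train_examples("glue_data/MRPC")  # assumed DataProcessor-style API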
This diff is collapsed.
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
...
@@ -3,7 +3,7 @@ from pytorch_pretrained_bert.modeling_xlnet import (
     XLNetConfig,
     XLNetModel,
     XLNetLMHeadModel,
-    XLNetForSequenceClassification
+    # XLNetForSequenceClassification
 )
 # A lot of models share the same param doc. Use a decorator
@@ -135,35 +135,35 @@ def xlnetLMHeadModel(*args, **kwargs):
     return model
-@_append_from_pretrained_docstring(xlnet_docstring)
-def xlnetForSequenceClassification(*args, **kwargs):
-    """
-    xlnetModel is the basic XLNet Transformer model from
-    "XLNet: Generalized Autoregressive Pretraining for Language Understanding"
-    by Zhilin Yang, Zihang Dai1, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le
-    Example:
-        # Load the tokenizer
-        >>> import torch
-        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetTokenizer', 'xlnet-large-cased')
-        # Prepare tokenized input
-        >>> text1 = "Who was Jim Henson ? Jim Henson was a puppeteer"
-        >>> text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man"
-        >>> tokenized_text1 = tokenizer.tokenize(text1)
-        >>> tokenized_text2 = tokenizer.tokenize(text2)
-        >>> indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1)
-        >>> indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2)
-        >>> tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]])
-        >>> mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]])
-        # Load xlnetForSequenceClassification
-        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetForSequenceClassification', 'xlnet-large-cased')
-        >>> model.eval()
-        # Predict sequence classes logits
-        >>> with torch.no_grad():
-                lm_logits, mems = model(tokens_tensor)
-    """
-    model = XLNetForSequenceClassification.from_pretrained(*args, **kwargs)
-    return model
+# @_append_from_pretrained_docstring(xlnet_docstring)
+# def xlnetForSequenceClassification(*args, **kwargs):
+#     """
+#     xlnetModel is the basic XLNet Transformer model from
+#     "XLNet: Generalized Autoregressive Pretraining for Language Understanding"
+#     by Zhilin Yang, Zihang Dai1, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le
+#     Example:
+#         # Load the tokenizer
+#         >>> import torch
+#         >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetTokenizer', 'xlnet-large-cased')
+#         # Prepare tokenized input
+#         >>> text1 = "Who was Jim Henson ? Jim Henson was a puppeteer"
+#         >>> text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man"
+#         >>> tokenized_text1 = tokenizer.tokenize(text1)
+#         >>> tokenized_text2 = tokenizer.tokenize(text2)
+#         >>> indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1)
+#         >>> indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2)
+#         >>> tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]])
+#         >>> mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]])
+#         # Load xlnetForSequenceClassification
+#         >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetForSequenceClassification', 'xlnet-large-cased')
+#         >>> model.eval()
+#         # Predict sequence classes logits
+#         >>> with torch.no_grad():
+#                 lm_logits, mems = model(tokens_tensor)
+#     """
+#     model = XLNetForSequenceClassification.from_pretrained(*args, **kwargs)
+#     return model
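Since the hub entry point is commented out for now, the new head is only reachable through the class itself. A minimal, hypothetical sketch of loading it directly; the shortcut name, the num_labels keyword forwarding, the batch-first input layout, and the (logits, new_mems) return value are all assumptions based on the surrounding code, not something this diff documents:

import torch
from pytorch_pretrained_bert.modeling_xlnet import XLNetForSequenceClassification

# 'xlnet-large-cased' mirrors the shortcut used in the commented-out hub docstring above
model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased', num_labels=2)
model.eval()

input_ids = torch.tensor([[31, 51, 99, 18]])   # placeholder token ids; a real run would use the XLNet tokenizer
with torch.no_grad():
    logits, new_mems = model(input_ids)        # assumed return when no target is passed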
@@ -1194,6 +1194,38 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         return logits, new_mems
         # return all_attentions, encoded_layers, pooled_output
+class XLNetSequenceSummary(nn.Module):
+    def __init__(self, config, summary_type="last", use_proj=True,
+                 output_attentions=False, keep_multihead_output=False):
+        super(XLNetSequenceSummary, self).__init__()
+        self.summary_type = summary_type
+        if use_proj:
+            # project the pooled vector back to the hidden size; the classifier head maps to the labels downstream
+            self.summary = nn.Linear(config.hidden_size, config.hidden_size)
+        else:
+            self.summary = None
+        if summary_type == 'attn':
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+        self.dropout = nn.Dropout(config.dropout)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states, input_mask=None):
+        if self.summary_type == 'last':
+            output = hidden_states[-1]
+        elif self.summary_type == 'first':
+            output = hidden_states[0]
+        elif self.summary_type == 'mean':
+            output = hidden_states.mean(dim=0)
+        elif self.summary_type == 'attn':
+            raise NotImplementedError
+        if self.summary is not None:
+            output = self.summary(output)
+        output = self.dropout(output)
+        output = self.activation(output)
+        return output
 class XLNetForSequenceClassification(XLNetPreTrainedModel):
     """XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").
@@ -1255,19 +1287,23 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, summary_type="last", output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
+                 is_regression=False, output_attentions=False, keep_multihead_output=False):
         super(XLNetForSequenceClassification, self).__init__(config)
         self.output_attentions = output_attentions
         self.attn_type = config.attn_type
         self.same_length = config.same_length
         self.summary_type = summary_type
+        self.is_regression = is_regression
         self.transformer = XLNetModel(config, output_attentions=output_attentions,
                                       keep_multihead_output=keep_multihead_output)
-        self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
-        self.apply(self.init_xlnet_weights)
-        self.tie_weights()
+        self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type,
+                                                     use_proj=use_proj, output_attentions=output_attentions,
+                                                     keep_multihead_output=keep_multihead_output)
+        # one output per label for classification, a single output for regression
+        self.loss_proj = nn.Linear(config.d_model, num_labels if not is_regression else 1)
+        self.apply(self.init_xlnet_weights)
     def forward(self, inp_k, seg_id=None, input_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
@@ -1295,17 +1331,20 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
                 Only used during pretraining for two-stream attention.
                 Set to None during finetuning.
         """
-        output, hidden_states, new_mems = self.transformer(inp_k, seg_id, input_mask,
+        output, _, new_mems = self.transformer(inp_k, seg_id, input_mask,
                                                mems, perm_mask, target_mapping, inp_q,
                                                output_all_encoded_layers, head_mask)
-        logits = self.lm_loss(output)
+        output = self.sequence_summary(output)
+        logits = self.loss_proj(output)
         if target is not None:
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
-            loss = loss_fct(logits.view(-1, logits.size(-1)),
-                            target.view(-1))
+            if self.is_regression:
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), target.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss(ignore_index=-1)
+                loss = loss_fct(logits.view(-1, logits.size(-1)), target.view(-1))
             return loss, new_mems
         # if self.output_attentions:
...
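The new forward thus supports both classification and regression targets (e.g. STS-B style regression in GLUE). A small, self-contained sketch in plain PyTorch (written for illustration, not taken from the commit) of the two loss computations it switches between:

import torch
from torch.nn import CrossEntropyLoss, MSELoss

num_labels = 2
logits = torch.randn(4, num_labels)                # stand-in for loss_proj output on a batch of 4

# classification path: cross-entropy over num_labels outputs
cls_target = torch.tensor([0, 1, 1, 0])
cls_loss = CrossEntropyLoss(ignore_index=-1)(logits.view(-1, num_labels), cls_target.view(-1))

# regression path (is_regression=True): a single output scored with mean-squared error
reg_logits = torch.randn(4, 1)
reg_target = torch.tensor([0.2, 1.0, 0.5, 0.7])
reg_loss = MSELoss()(reg_logits.view(-1), reg_target.view(-1))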