Commit f6081f22 authored by thomwolf

Add XLNetForSequenceClassification and a run_classifier example for XLNet

parent c946bb51
......@@ -14,7 +14,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
from pytorch_pretrained_bert import BertForSequenceClassification, BertTokenizer
from run_classifier_dataset_utils import processors, output_modes, convert_examples_to_features, compute_metrics
from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
logger = logging.getLogger(__name__)
......
......@@ -39,7 +39,7 @@ from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
from run_classifier_dataset_utils import processors, output_modes, convert_examples_to_features, compute_metrics
from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
if sys.version_info[0] == 2:
import cPickle as pickle
......
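The two hunks above only rename the helper module for the GLUE example (run_classifier_dataset_utils -> utils_glue). Below is a minimal sketch of the interface the example script relies on, assuming utils_glue keeps the old module's signatures (the processors and output_modes dicts and convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_mode)); the task name and data path are hypothetical:
```
import torch
from pytorch_pretrained_bert import BertTokenizer
from utils_glue import processors, output_modes, convert_examples_to_features

task_name = "mrpc"            # hypothetical task choice
data_dir = "glue_data/MRPC"   # hypothetical data location

processor = processors[task_name]()      # e.g. MrpcProcessor
output_mode = output_modes[task_name]    # "classification" or "regression"
label_list = processor.get_labels()
examples = processor.get_dev_examples(data_dir)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
features = convert_examples_to_features(examples, label_list, 128, tokenizer, output_mode)

# Each feature exposes input_ids / input_mask / segment_ids / label_id
input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
```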
......@@ -38,7 +38,7 @@ from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
from pytorch_pretrained_bert.tokenization import BertTokenizer
from run_squad_dataset_utils import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
if sys.version_info[0] == 2:
import cPickle as pickle
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
......
......@@ -3,7 +3,7 @@ from pytorch_pretrained_bert.modeling_xlnet import (
XLNetConfig,
XLNetModel,
XLNetLMHeadModel,
XLNetForSequenceClassification
# XLNetForSequenceClassification
)
# A lot of models share the same param doc. Use a decorator
......@@ -135,35 +135,35 @@ def xlnetLMHeadModel(*args, **kwargs):
return model
@_append_from_pretrained_docstring(xlnet_docstring)
def xlnetForSequenceClassification(*args, **kwargs):
"""
xlnetForSequenceClassification is the XLNet Transformer model with a sequence classification head on top, from
"XLNet: Generalized Autoregressive Pretraining for Language Understanding"
by Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetTokenizer', 'xlnet-large-cased')
# Prepare tokenized input
>>> text1 = "Who was Jim Henson ? Jim Henson was a puppeteer"
>>> text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man"
>>> tokenized_text1 = tokenizer.tokenize(text1)
>>> tokenized_text2 = tokenizer.tokenize(text2)
>>> indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1)
>>> indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2)
>>> tokens_tensor_1 = torch.tensor([indexed_tokens1])
>>> tokens_tensor_2 = torch.tensor([indexed_tokens2])
# Load xlnetForSequenceClassification
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetForSequenceClassification', 'xlnet-large-cased')
>>> model.eval()
# Predict sequence classes logits
>>> with torch.no_grad():
logits, mems = model(tokens_tensor_1)
"""
model = XLNetForSequenceClassification.from_pretrained(*args, **kwargs)
return model
# @_append_from_pretrained_docstring(xlnet_docstring)
# def xlnetForSequenceClassification(*args, **kwargs):
# """
# xlnetForSequenceClassification is the XLNet Transformer model with a sequence classification head on top, from
# "XLNet: Generalized Autoregressive Pretraining for Language Understanding"
# by Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le
# Example:
# # Load the tokenizer
# >>> import torch
# >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetTokenizer', 'xlnet-large-cased')
# # Prepare tokenized input
# >>> text1 = "Who was Jim Henson ? Jim Henson was a puppeteer"
# >>> text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man"
# >>> tokenized_text1 = tokenizer.tokenize(text1)
# >>> tokenized_text2 = tokenizer.tokenize(text2)
# >>> indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1)
# >>> indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2)
# >>> tokens_tensor_1 = torch.tensor([indexed_tokens1])
# >>> tokens_tensor_2 = torch.tensor([indexed_tokens2])
# # Load xlnetForSequenceClassification
# >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetForSequenceClassification', 'xlnet-large-cased')
# >>> model.eval()
# # Predict sequence classes logits
# >>> with torch.no_grad():
# logits, mems = model(tokens_tensor_1)
# """
# model = XLNetForSequenceClassification.from_pretrained(*args, **kwargs)
# return model
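The torch.hub entry point above is commented out in this commit, so here is a minimal sketch of loading the new head directly from the library instead. Assumptions: an XLNetTokenizer class lives in pytorch_pretrained_bert.tokenization_xlnet, from_pretrained forwards extra keyword arguments (num_labels) to the constructor as the BERT classes do, and the forward returns (logits, new_mems) when no target is passed, matching the diff below:
```
import torch
from pytorch_pretrained_bert.tokenization_xlnet import XLNetTokenizer        # assumed module/class name
from pytorch_pretrained_bert.modeling_xlnet import XLNetForSequenceClassification

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased', num_labels=2)
model.eval()

text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokens = tokenizer.tokenize(text)
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])  # batch of one sequence

with torch.no_grad():
    logits, mems = model(input_ids)   # no target given -> classification logits and new memory states
```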
......@@ -1194,6 +1194,38 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
return logits, new_mems
# return all_attentions, encoded_layers, pooled_output
class XLNetSequenceSummary(nn.Module):
def __init__(self, config, summary_type="last", use_proj=True,
output_attentions=False, keep_multihead_output=False):
super(XLNetSequenceSummary, self).__init__()
self.summary_type = summary_type
if use_proj:
self.summary = nn.Linear(config.d_model, config.d_model)  # hidden-to-hidden projection; the classification projection is loss_proj in the model below
else:
self.summary = None
if summary_type == 'attn':
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0 (see the sketch after this class)
raise NotImplementedError
self.dropout = nn.Dropout(config.dropout)
self.activation = nn.Tanh()
def forward(self, hidden_states, input_mask=None):
if self.summary_type == 'last':
output = hidden_states[-1]
elif self.summary_type == 'first':
output = hidden_states[0]
elif self.summary_type == 'mean':
output = hidden_states.mean(dim=0)
elif self.summary_type == 'attn':
raise NotImplementedError
if self.summary is not None:
    output = self.summary(output)
output = self.dropout(output)
output = self.activation(output)
return output
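The 'attn' summary above is left as NotImplementedError. A possible sketch using torch.nn.MultiheadAttention (available since PyTorch 1.1.0, as the comment suggests), with a single learned query attending over the time-first hidden states; this is an illustration only and omits the absolute positional embedding mentioned in the comment:
```
import torch
import torch.nn as nn

class AttnSummary(nn.Module):
    """Illustrative attention-based sequence summary.

    Takes hidden states of shape (seq_len, batch, d_model), as in the code above,
    and returns a (batch, d_model) summary computed by a single learned query.
    """
    def __init__(self, d_model, num_heads=1, dropout=0.1):
        super(AttnSummary, self).__init__()
        self.query = nn.Parameter(torch.zeros(1, 1, d_model))   # zero init -> starts as mean pooling
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)

    def forward(self, hidden_states):
        batch_size = hidden_states.size(1)
        query = self.query.expand(-1, batch_size, -1)            # (1, batch, d_model)
        summary, _ = self.attn(query, hidden_states, hidden_states)
        return summary.squeeze(0)                                 # (batch, d_model)
```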
class XLNetForSequenceClassification(XLNetPreTrainedModel):
"""XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").
......@@ -1255,19 +1287,23 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, summary_type="last", output_attentions=False, keep_multihead_output=False):
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
is_regression=False, output_attentions=False, keep_multihead_output=False):
super(XLNetForSequenceClassification, self).__init__(config)
self.output_attentions = output_attentions
self.attn_type = config.attn_type
self.same_length = config.same_length
self.summary_type = summary_type
self.is_regression = is_regression
self.transformer = XLNetModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
self.apply(self.init_xlnet_weights)
self.tie_weights()
self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type,
use_proj=use_proj, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.loss_proj = nn.Linear(config.d_model, num_labels if not is_regression else 1)
self.apply(self.init_xlnet_weights)
def forward(self, inp_k, seg_id=None, input_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
......@@ -1295,17 +1331,20 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Only used during pretraining for two-stream attention.
Set to None during finetuning.
"""
output, hidden_states, new_mems = self.transformer(inp_k, seg_id, input_mask,
output, _, new_mems = self.transformer(inp_k, seg_id, input_mask,
mems, perm_mask, target_mapping, inp_q,
output_all_encoded_layers, head_mask)
logits = self.lm_loss(output)
output = self.sequence_summary(output)
logits = self.loss_proj(output)
if target is not None:
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(logits.view(-1, logits.size(-1)),
target.view(-1))
if self.is_regression:
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), target.view(-1))
else:
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(logits.view(-1, logits.size(-1)), target.view(-1))
return loss, new_mems
# if self.output_attentions:
......
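For reference, a standalone sketch of what the new loss branch computes, with dummy logits and targets (shapes assumed: (batch, num_labels) logits and integer targets for classification, a single output per example and float targets for regression):
```
import torch
from torch.nn import CrossEntropyLoss, MSELoss

batch, num_labels = 4, 2

# Classification path (is_regression=False): cross-entropy over num_labels classes
logits = torch.randn(batch, num_labels)
target = torch.tensor([0, 1, 1, 0])
clf_loss = CrossEntropyLoss(ignore_index=-1)(logits.view(-1, logits.size(-1)), target.view(-1))

# Regression path (is_regression=True): mean-squared error on a single score per example
reg_logits = torch.randn(batch, 1)
reg_target = torch.tensor([0.2, 1.0, 0.5, 0.0])
reg_loss = MSELoss()(reg_logits.view(-1), reg_target.view(-1))

print(clf_loss.item(), reg_loss.item())
```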