Commit fbe04423 authored by thomwolf

Common SequenceSummary class

parent c22545aa
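
This commit replaces the model-specific heads (GPT2MultipleChoiceHead, OpenAIGPTMultipleChoiceHead, XLNetSequenceSummary) with a single SequenceSummary module in model_utils, driven by summary_* attributes on each config. As a reading aid (not part of the commit), here is a minimal sketch of how those attributes map onto the layers SequenceSummary builds; the SimpleNamespace config and its values are illustrative only and mirror the new GPT-2 defaults in the diff below.

```python
from types import SimpleNamespace

import torch.nn as nn

# Illustrative config, mirroring the new GPT-2 defaults added by this commit
# (summary_type='token_ids', a projection down to 1 class, no activation).
cfg = SimpleNamespace(summary_type='token_ids', summary_use_proj=True,
                      summary_num_classes=1, summary_activation=None,
                      summary_dropout=0.1, hidden_size=768)

# What SequenceSummary.__init__ ends up building for this config:
summary = nn.Linear(cfg.hidden_size, cfg.summary_num_classes)  # summary_use_proj and summary_num_classes > 0
activation = nn.Identity()                                     # summary_activation is not 'tanh'
dropout = nn.Dropout(cfg.summary_dropout)                      # applied to the extracted vector
```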
@@ -17,7 +17,7 @@ from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
 from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
                                   load_tf_weights_in_transfo_xl)
 from .modeling_gpt2 import (GPT2Config, GPT2Model,
-                            GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2MultipleChoiceHead,
+                            GPT2LMHeadModel, GPT2DoubleHeadsModel,
                             load_tf_weights_in_gpt2)
 from .modeling_xlnet import (XLNetConfig,
                              XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
...
@@ -282,6 +282,95 @@ class PreTrainedModel(nn.Module):
         return model
+
+class Conv1D(nn.Module):
+    def __init__(self, nf, nx):
+        """ Conv1D layer as defined by Alec Radford for GPT (and also used in GPT-2).
+            Basically works like a Linear layer but the weights are transposed.
+        """
+        super(Conv1D, self).__init__()
+        self.nf = nf
+        w = torch.empty(nx, nf)
+        nn.init.normal_(w, std=0.02)
+        self.weight = nn.Parameter(w)
+        self.bias = nn.Parameter(torch.zeros(nf))
+
+    def forward(self, x):
+        size_out = x.size()[:-1] + (self.nf,)
+        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+        x = x.view(*size_out)
+        return x
+
+class SequenceSummary(nn.Module):
+    def __init__(self, config):
+        """ Compute a single vector summary of a sequence's hidden states, according to various possibilities.
+            Args of the config class:
+                summary_type:
+                    - 'last' => [default] take the hidden state of the last token (like XLNet)
+                    - 'first' => take the hidden state of the first token (like BERT)
+                    - 'mean' => take the mean of all tokens' hidden states
+                    - 'token_ids' => supply a Tensor of classification token indices (GPT/GPT-2)
+                    - 'attn' => not implemented for now; would use multi-head attention
+                summary_use_proj: add a projection after the vector extraction
+                summary_num_classes: if > 0, the projection outputs that many classes (otherwise hidden_size)
+                summary_activation:
+                    'tanh' => add a tanh activation to the output
+                    None => no activation
+        """
+        super(SequenceSummary, self).__init__()
+
+        self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last'
+        if self.summary_type == 'attn':
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.summary = nn.Identity()
+        if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
+            if hasattr(config, 'summary_num_classes') and config.summary_num_classes > 0:
+                num_classes = config.summary_num_classes
+            else:
+                num_classes = config.hidden_size
+            self.summary = nn.Linear(config.hidden_size, num_classes)
+
+        self.activation = nn.Identity()
+        if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
+            self.activation = nn.Tanh()
+
+        self.dropout = nn.Dropout(config.summary_dropout)
+
+    def forward(self, hidden_states, token_ids=None):
+        """ hidden_states: float Tensor of shape [bsz, seq_len, hidden_size], the hidden states of the last layer.
+            token_ids: [optional] index of the classification token if summary_type == 'token_ids',
+                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
+                If summary_type == 'token_ids' and token_ids is None,
+                we take the last token of the sequence as the classification token.
+        """
+        if self.summary_type == 'last':
+            output = hidden_states[:, -1]
+        elif self.summary_type == 'first':
+            output = hidden_states[:, 0]
+        elif self.summary_type == 'mean':
+            output = hidden_states.mean(dim=1)
+        elif self.summary_type == 'token_ids':
+            if token_ids is None:
+                token_ids = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
+            else:
+                token_ids = token_ids.unsqueeze(-1).unsqueeze(-1)
+                token_ids = token_ids.expand((-1,) * (token_ids.dim()-1) + (hidden_states.size(-1),))
+            # shape of token_ids: (bsz, XX, 1, hidden_size) where XX are optional leading dims of hidden_states
+            output = hidden_states.gather(-2, token_ids).squeeze(-2)  # shape (bsz, XX, hidden_size)
+        elif self.summary_type == 'attn':
+            raise NotImplementedError
+
+        output = self.summary(output)
+        output = self.activation(output)
+        output = self.dropout(output)
+
+        return output

 def prune_linear_layer(layer, index, dim=0):
     """ Prune a linear layer (a model parameters) to keep only entries in index.
         Return the pruned layer as a new layer with requires_grad=True.
@@ -307,25 +396,6 @@ def prune_linear_layer(layer, index, dim=0):
     return new_layer
-class Conv1D(nn.Module):
-    """ Conv1D layer as defined by Alec Radford for GPT (and also used in GPT-2)
-        Basically works like a Linear layer but the weights are transposed
-    """
-    def __init__(self, nf, nx):
-        super(Conv1D, self).__init__()
-        self.nf = nf
-        w = torch.empty(nx, nf)
-        nn.init.normal_(w, std=0.02)
-        self.weight = nn.Parameter(w)
-        self.bias = nn.Parameter(torch.zeros(nf))
-
-    def forward(self, x):
-        size_out = x.size()[:-1] + (self.nf,)
-        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(*size_out)
-        return x

 def prune_conv1d_layer(layer, index, dim=1):
     """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
         A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
...
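
To make the new 'token_ids' branch of SequenceSummary concrete, here is a small standalone shape check (illustrative sizes only, not part of the commit): for hidden states of shape (bsz, num_choices, seq_len, hidden_size), the double unsqueeze, expand, and gather pick out one hidden state per choice.

```python
import torch

bsz, num_choices, seq_len, hidden = 2, 4, 7, 16                # made-up sizes
hidden_states = torch.randn(bsz, num_choices, seq_len, hidden)
token_ids = torch.randint(0, seq_len, (bsz, num_choices))      # classification token index per choice

idx = token_ids.unsqueeze(-1).unsqueeze(-1)                    # (bsz, num_choices, 1, 1)
idx = idx.expand((-1,) * (idx.dim() - 1) + (hidden,))          # (bsz, num_choices, 1, hidden)
summary = hidden_states.gather(-2, idx).squeeze(-2)            # (bsz, num_choices, hidden)
assert summary.shape == (bsz, num_choices, hidden)
```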
@@ -31,7 +31,8 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter
 
 from .file_utils import cached_path
-from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
+from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
+                          PreTrainedModel, prune_conv1d_layer, SequenceSummary)
 from .modeling_bert import BertLayerNorm as LayerNorm
 
 logger = logging.getLogger(__name__)
@@ -119,6 +120,11 @@ class GPT2Config(PretrainedConfig):
                  layer_norm_epsilon=1e-5,
                  initializer_range=0.02,
                  predict_special_tokens=True,
+                 summary_type='token_ids',
+                 summary_use_proj=True,
+                 summary_num_classes=1,
+                 summary_activation=None,
+                 summary_dropout=0.1,
                  **kwargs
                  ):
"""Constructs GPT2Config. """Constructs GPT2Config.
...@@ -164,6 +170,11 @@ class GPT2Config(PretrainedConfig): ...@@ -164,6 +170,11 @@ class GPT2Config(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.predict_special_tokens = predict_special_tokens self.predict_special_tokens = predict_special_tokens
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_num_classes = summary_num_classes
self.summary_activation = summary_activation
self.summary_dropout = summary_dropout
         else:
             raise ValueError(
                 "First argument must be either a vocabulary size (int)"
@@ -342,37 +353,6 @@ class GPT2LMHead(nn.Module):
         return lm_logits

-class GPT2MultipleChoiceHead(nn.Module):
-    """ Classifier Head for the transformer """
-
-    def __init__(self, config):
-        super(GPT2MultipleChoiceHead, self).__init__()
-        self.n_embd = config.n_embd
-        self.dropout = nn.Dropout2d(config.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
-        self.linear = nn.Linear(config.n_embd, 1)
-
-        nn.init.normal_(self.linear.weight, std=0.02)
-        nn.init.normal_(self.linear.bias, 0)
-
-    def forward(self, hidden_states, mc_token_ids=None):
-        """ Extract classification token hidden state and project it using self.linear
-            hidden_state: shape (bsz, num_choices, seq_length, hidden_size)
-            mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
-                if mc_token_ids=None we take the last token of the sequence as classification token
-        """
-        if mc_token_ids is None:
-            mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
-        else:
-            mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
-        # mc_token_ids has shape (bsz, num_choices, 1, hidden_size)
-        multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
-        # (bsz, num_choices, hidden_size)
-        multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
-        multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
-        # (bsz, num_choices)
-        return multiple_choice_logits

 class GPT2PreTrainedModel(PreTrainedModel):
     """ An abstract class to handle weights initialization and
         a simple interface for dowloading and loading pretrained models.
@@ -735,7 +715,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         super(GPT2DoubleHeadsModel, self).__init__(config)
         self.transformer = GPT2Model(config)
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
-        self.multiple_choice_head = GPT2MultipleChoiceHead(config)
+        self.multiple_choice_head = SequenceSummary(config)
 
         self.apply(self.init_weights)
@@ -753,7 +733,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         hidden_states = transformer_outputs[0]
 
         lm_logits = self.lm_head(hidden_states)
-        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
+        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
 
         outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
...
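
A note on the new .squeeze(-1) in GPT2DoubleHeadsModel.forward: with summary_num_classes=1, the projection inside SequenceSummary leaves a trailing dimension of size 1, which the old GPT2MultipleChoiceHead squeezed internally. A quick illustrative shape check (made-up sizes, not part of the commit):

```python
import torch
import torch.nn as nn

bsz, num_choices, hidden = 2, 4, 16              # made-up sizes
summary = torch.randn(bsz, num_choices, hidden)  # output of the 'token_ids' gather
proj = nn.Linear(hidden, 1)                      # what summary_use_proj builds when summary_num_classes=1

mc_logits = proj(summary)                        # (bsz, num_choices, 1)
mc_logits = mc_logits.squeeze(-1)                # (bsz, num_choices), same shape as the old head returned
```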
@@ -31,7 +31,8 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter
 
 from .file_utils import cached_path
-from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
+from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
+                          PreTrainedModel, prune_conv1d_layer, SequenceSummary)
 from .modeling_bert import BertLayerNorm as LayerNorm
 
 logger = logging.getLogger(__name__)
@@ -147,6 +148,11 @@ class OpenAIGPTConfig(PretrainedConfig):
                  layer_norm_epsilon=1e-5,
                  initializer_range=0.02,
                  predict_special_tokens=True,
+                 summary_type='token_ids',
+                 summary_use_proj=True,
+                 summary_num_classes=1,
+                 summary_activation=None,
+                 summary_dropout=0.1,
                  **kwargs
                  ):
"""Constructs OpenAIGPTConfig. """Constructs OpenAIGPTConfig.
...@@ -195,6 +201,11 @@ class OpenAIGPTConfig(PretrainedConfig): ...@@ -195,6 +201,11 @@ class OpenAIGPTConfig(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.predict_special_tokens = predict_special_tokens self.predict_special_tokens = predict_special_tokens
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_num_classes = summary_num_classes
self.summary_activation = summary_activation
self.summary_dropout = summary_dropout
         else:
             raise ValueError(
                 "First argument must be either a vocabulary size (int)"
@@ -368,37 +379,6 @@ class OpenAIGPTLMHead(nn.Module):
         return lm_logits

-class OpenAIGPTMultipleChoiceHead(nn.Module):
-    """ Classifier Head for the transformer """
-
-    def __init__(self, config):
-        super(OpenAIGPTMultipleChoiceHead, self).__init__()
-        self.n_embd = config.n_embd
-        self.dropout = nn.Dropout2d(config.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
-        self.linear = nn.Linear(config.n_embd, 1)
-
-        nn.init.normal_(self.linear.weight, std=0.02)
-        nn.init.normal_(self.linear.bias, 0)
-
-    def forward(self, hidden_states, mc_token_ids=None):
-        """ Extract classification token hidden state and project it using self.linear
-            hidden_state: hidden state of shape (bsz, num_choices, seq_length, hidden_size)
-            mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
-                if mc_token_ids=None we take the last token of the sequence as classification token
-        """
-        if mc_token_ids is None:
-            mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
-        else:
-            mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
-        # (bsz, num_choices, 1, hidden_size)
-        multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
-        # (bsz, num_choices, hidden_size)
-        multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
-        multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
-        # (bsz, num_choices)
-        return multiple_choice_logits

 class OpenAIGPTPreTrainedModel(PreTrainedModel):
     """ An abstract class to handle weights initialization and
         a simple interface for dowloading and loading pretrained models.
@@ -768,9 +748,11 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     def __init__(self, config):
         super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
         self.transformer = OpenAIGPTModel(config)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
-        self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
+        self.multiple_choice_head = SequenceSummary(config)
 
         self.apply(self.init_weights)
 
     def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
@@ -787,7 +769,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         hidden_states = transformer_outputs[0]
 
         lm_logits = self.lm_head(hidden_states)
-        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
+        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
 
         outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
...
@@ -32,7 +32,8 @@ from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss
 
 from .file_utils import cached_path
-from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
+from .model_utils import (CONFIG_NAME, WEIGHTS_NAME,
+                          PretrainedConfig, PreTrainedModel, SequenceSummary)
 
 logger = logging.getLogger(__name__)
@@ -223,8 +224,10 @@ class XLNetConfig(PretrainedConfig):
                  finetuning_task=None,
                  num_labels=2,
-                 summary_type="last",
-                 use_proj=True,
+                 summary_type='last',
+                 summary_use_proj=True,
+                 summary_activation='tanh',
+                 summary_dropout=0.1,
                  **kwargs):
"""Constructs XLNetConfig. """Constructs XLNetConfig.
...@@ -307,7 +310,9 @@ class XLNetConfig(PretrainedConfig): ...@@ -307,7 +310,9 @@ class XLNetConfig(PretrainedConfig):
self.finetuning_task = finetuning_task self.finetuning_task = finetuning_task
self.num_labels = num_labels self.num_labels = num_labels
self.summary_type = summary_type self.summary_type = summary_type
self.use_proj = use_proj self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_dropout = summary_dropout
else: else:
raise ValueError("First argument must be either a vocabulary size (int)" raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)") "or the path to a pretrained model config file (str)")
@@ -1042,38 +1047,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)

-class XLNetSequenceSummary(nn.Module):
-    def __init__(self, config):
-        super(XLNetSequenceSummary, self).__init__()
-        self.summary_type = config.summary_type
-        if config.use_proj:
-            self.summary = nn.Linear(config.d_model, config.d_model)
-        else:
-            self.summary = None
-        if config.summary_type == 'attn':
-            # We should use a standard multi-head attention module with absolute positional embedding for that.
-            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
-            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
-            raise NotImplementedError
-        self.dropout = nn.Dropout(config.dropout)
-        self.activation = nn.Tanh()
-
-    def forward(self, hidden_states):
-        """ hidden_states: float Tensor in shape [bsz, seq_len, d_model], the hidden-states of the last layer."""
-        if self.summary_type == 'last':
-            output = hidden_states[:, -1]
-        elif self.summary_type == 'first':
-            output = hidden_states[:, 0]
-        elif self.summary_type == 'mean':
-            output = hidden_states.mean(dim=1)
-        elif self.summary_type == 'attn':
-            raise NotImplementedError
-        output = self.summary(output)
-        output = self.activation(output)
-        output = self.dropout(output)
-        return output

 class XLNetForSequenceClassification(XLNetPreTrainedModel):
     """XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").
@@ -1143,7 +1116,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         super(XLNetForSequenceClassification, self).__init__(config)
         self.transformer = XLNetModel(config)
-        self.sequence_summary = XLNetSequenceSummary(config)
+        self.sequence_summary = SequenceSummary(config)
         self.logits_proj = nn.Linear(config.d_model, config.num_labels)
 
         self.apply(self.init_weights)
...
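
For XLNet, the config attributes are renamed (use_proj becomes summary_use_proj) and the hard-coded Tanh and dropout of the old XLNetSequenceSummary become summary_activation='tanh' and summary_dropout. A hedged sketch of the equivalent settings (using hidden_size as a stand-in for XLNet's d_model is an assumption of this illustration):

```python
from types import SimpleNamespace

# Illustrative values mirroring the new XLNetConfig defaults in this commit.
# hidden_size is used here as a stand-in for XLNet's d_model.
cfg = SimpleNamespace(summary_type='last', summary_use_proj=True,
                      summary_activation='tanh', summary_dropout=0.1,
                      hidden_size=1024)
# With no summary_num_classes attribute, SequenceSummary builds
# Linear(hidden_size, hidden_size), Tanh() and Dropout(0.1),
# matching what XLNetSequenceSummary did with config.use_proj and nn.Tanh().
```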