Commit 0ef9bc92 authored by thomwolf

Cleaning up seq2seq [WIP]

parent b3261e7a
@@ -199,12 +199,14 @@ class BertSelfAttention(nn.Module):
         return x.permute(0, 2, 1, 3)
 
     def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
-        mixed_key_layer = self.key(hidden_states)
-        mixed_value_layer = self.value(hidden_states)
-        if encoder_hidden_states is not None:  # if encoder-decoder attention
-            mixed_query_layer = self.query(encoder_hidden_states)
+        mixed_query_layer = self.query(hidden_states)
+        # if the attention module is an encoder-decoder cross-attention module, keys and values come from the encoder
+        if encoder_hidden_states is not None:
+            mixed_key_layer = self.key(encoder_hidden_states)
+            mixed_value_layer = self.value(encoder_hidden_states)
         else:
-            mixed_query_layer = self.query(hidden_states)
+            mixed_key_layer = self.key(hidden_states)
+            mixed_value_layer = self.value(hidden_states)
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
         key_layer = self.transpose_for_scores(mixed_key_layer)
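With this change the queries keep the current (decoder-side) sequence length while the keys and values take the encoder's, so cross-attention scores have shape (batch, heads, target_len, source_len). A minimal sketch of that shape arithmetic, using made-up sizes rather than anything from this diff:

    import torch

    # Hypothetical sizes, for illustration only
    batch, num_heads, head_dim = 2, 12, 64
    tgt_len, src_len = 7, 11                                   # decoder vs. encoder lengths

    query = torch.randn(batch, num_heads, tgt_len, head_dim)   # from hidden_states
    key = torch.randn(batch, num_heads, src_len, head_dim)     # from encoder_hidden_states
    value = torch.randn(batch, num_heads, src_len, head_dim)   # from encoder_hidden_states

    scores = torch.matmul(query, key.transpose(-1, -2)) / head_dim ** 0.5
    context = torch.matmul(scores.softmax(dim=-1), value)
    print(scores.shape)    # torch.Size([2, 12, 7, 11])
    print(context.shape)   # torch.Size([2, 12, 7, 64])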
@@ -322,26 +324,25 @@ class BertLayer(nn.Module):
     def __init__(self, config):
         super(BertLayer, self).__init__()
         self.attention = BertAttention(config)
-        if getattr(config, "is_decoder", False):
+        self.is_decoder = config.is_decoder
+        if self.is_decoder:
             self.crossattention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
 
     def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
-        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
-        attention_output = attention_outputs[0]
-        if encoder_hidden_state is not None:
-            try:
-                attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
-            except AttributeError as ae:
-                print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae)
-                raise
-            attention_output = attention_outputs[0]
+        self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        if self.is_decoder and encoder_hidden_state is not None:
+            cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
 
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
+        outputs = (layer_output,) + outputs
         return outputs
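Because the decoder layer now appends its cross-attention weights after the self-attention weights, the returned tuple (when config.output_attentions is enabled) can be unpacked as below; the tensors are stand-ins with illustrative shapes, not values produced by this diff:

    import torch

    # Stand-in tensors mimicking a decoder BertLayer's return value when
    # attention weights are returned (hypothetical shapes)
    batch, heads, tgt_len, src_len, hidden = 2, 12, 7, 11, 768
    layer_outputs = (torch.zeros(batch, tgt_len, hidden),            # layer_output
                     torch.zeros(batch, heads, tgt_len, tgt_len),    # self-attention weights
                     torch.zeros(batch, heads, tgt_len, src_len))    # cross-attention weights

    layer_output, self_attn_weights, cross_attn_weights = layer_outputs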
@@ -352,14 +353,14 @@ class BertEncoder(nn.Module):
         self.output_hidden_states = config.output_hidden_states
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
         all_hidden_states = ()
         all_attentions = ()
         for i, layer_module in enumerate(self.layer):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
+            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states)
             hidden_states = layer_outputs[0]
 
             if self.output_attentions:
@@ -377,42 +378,6 @@ class BertEncoder(nn.Module):
         return outputs  # last-layer hidden state, (all hidden states), (all attentions)
class BertDecoder(nn.Module):
def __init__(self, config):
super(BertDecoder, self).__init__()
config.is_decoder = True
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_state=encoder_outputs)
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
hidden_states = layer_outputs[0]
# Add last layer
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
 class BertPooler(nn.Module):
     def __init__(self, config):
         super(BertPooler, self).__init__()
@@ -635,7 +600,8 @@ class BertModel(BertPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
+                head_mask=None, encoder_hidden_state=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
@@ -673,8 +639,9 @@ class BertModel(BertPreTrainedModel):
         embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
         encoder_outputs = self.encoder(embedding_output,
-                                       extended_attention_mask,
-                                       head_mask=head_mask)
+                                       attention_mask=extended_attention_mask,
+                                       head_mask=head_mask,
+                                       encoder_hidden_states=encoder_hidden_state)
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
@@ -682,111 +649,6 @@ class BertModel(BertPreTrainedModel):
         return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
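The extra `encoder_hidden_state` argument is what lets a BertModel configured as a decoder attend over the output of another BertModel. A rough sketch of that wiring, mirroring the pattern `PreTrainedSeq2seq` uses further down; the import path, checkpoint name and sentence are assumptions, not part of this commit:

    import torch
    from transformers import BertConfig, BertModel, BertTokenizer   # package name assumed

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    encoder_config = BertConfig.from_pretrained('bert-base-uncased')
    encoder_config.is_decoder = False
    encoder = BertModel.from_pretrained('bert-base-uncased', config=encoder_config)

    decoder_config = BertConfig.from_pretrained('bert-base-uncased')
    decoder_config.is_decoder = True                  # adds the cross-attention sub-layer
    decoder = BertModel.from_pretrained('bert-base-uncased', config=decoder_config)

    input_ids = torch.tensor([tokenizer.encode("Hello, how are you?")])

    encoder_hidden = encoder(input_ids)[0]            # (batch, seq_len, hidden_size)
    # Decode the same ids through the decoder stack; with this [WIP] masking the
    # decoder input is kept the same length as the encoder input.
    sequence_output = decoder(input_ids, encoder_hidden_state=encoder_hidden)[0]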
@add_start_docstrings("""A bare Bert decoder Model transformer outputting raw hidden-states without any specific head on top.
The model follows the general transformer decoder architecture.""",
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING)
class BertDecoderModel(BertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the output of the last layer of the model.
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertDecoderModel.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(BertDecoderModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.decoder = BertDecoder(config)
self.pooler = BertPooler(config)
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
return self.embeddings.word_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.decoder.layer[layer].attention.prune_heads(heads)
self.decoder.layer[layer].crossattention.prune_heads(heads)
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
decoder_outputs = self.decoder(embedding_output,
encoder_outputs,
extended_attention_mask,
head_mask=head_mask)
sequence_output = decoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + decoder_outputs[1:] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training: @add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
a `masked language modeling` head and a `next sentence prediction (classification)` head. """, a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
BERT_START_DOCSTRING, BERT_START_DOCSTRING,
...@@ -1309,101 +1171,3 @@ class BertForQuestionAnswering(BertPreTrainedModel): ...@@ -1309,101 +1171,3 @@ class BertForQuestionAnswering(BertPreTrainedModel):
outputs = (total_loss,) + outputs outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
@add_start_docstrings("Bert encoder-decoder model for sequence generation.",
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING)
class Bert2Rnd(BertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = Bert2Rnd.from_pretrained('bert-base-uncased')
# fine-tuning magic happens here
input = tokenizer.encode("Hello, how are you?")
outputs = model(input)
output_text = tokenize.decode(outputs[0])
print(output_text)
References::
[1] "Leveraging Pre-trained Checkpoints for Sequence Generation Tasks", S.Rothe, S.Narayan & A.Severyn (2019) ArXiV:1907.12461v1
[2] Tensor2Tensor library https://github.com/tensorflow/tensor2tensor
"""
def __init__(self, config):
super(Bert2Rnd, self).__init__(config)
self.encoder = BertModel(config)
self.decoder = BertDecoderModel(config)
@classmethod
def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
""" Load the pretrained weights in the encoder.
The encoder of `Bert2Rand` is initialized with pretrained weights; the
weights of the decoder are initialized at random except the embeddings
which are initialized with the pretrained embeddings. We thus need to override
the base class' `from_pretrained` method.
"""
# Load the configuration
config = model_kwargs.pop('config', None)
if config is None:
cache_dir = model_kwargs.pop('cache_dir', None)
force_download = model_kwargs.pop('force_download', False)
config, _ = cls.config_class.from_pretrained(
pretrained_model_or_path,
*model_args,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
**model_kwargs
)
model = cls(config)
# We load the encoder with pretrained weights
pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
model.encoder = pretrained_encoder
# We load the decoder with pretrained weights and then randomize all weights but embeddings-related one.
def randomize_decoder_weights(module):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
pretrained_decoder = BertDecoderModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
pretrained_decoder.apply(randomize_decoder_weights)
model.decoder = pretrained_decoder
return model
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
encoder_outputs = self.encoder(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
decoder_outputs = self.decoder(input_ids,
encoder_outputs[0],
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
return decoder_outputs
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import torch
from torch import nn
from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
from .modeling_utils import PreTrainedModel, SequenceSummary
from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
class PreTrainedSeq2seq(nn.Module):
r"""
    :class:`~transformers.PreTrainedSeq2seq` is a generic model class
    that will be instantiated as a Seq2seq model with one of the base model classes of the library
    as encoder and (optionally) as decoder when created with the `PreTrainedSeq2seq.from_pretrained(pretrained_model_name_or_path)`
    class method.
The `from_pretrained()` method takes care of returning the correct model class instance
using pattern matching on the `pretrained_model_name_or_path` string.
The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
    This class is typically instantiated using the `from_pretrained()` class method.
"""
def __init__(self, encoder, decoder):
super(PreTrainedSeq2seq, self).__init__()
self.encoder = encoder
self.decoder = decoder
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the base model classes of the library
from a pre-trained model configuration.
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: DistilBertModel (DistilBERT model)
- contains `roberta`: RobertaModel (RoBERTa model)
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
Params:
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method
            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
cache_dir: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attentions=True``). It behaves differently depending on whether a `config` is provided or automatically loaded:
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
Examples::
            model = PreTrainedSeq2seq.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = PreTrainedSeq2seq.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = PreTrainedSeq2seq.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.encoder.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
            model = PreTrainedSeq2seq.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
# Extract encoder and decoder model if provided
encoder_model = kwargs.pop('encoder_model', None)
decoder_model = kwargs.pop('decoder_model', None)
# Extract decoder kwargs so we only have encoder kwargs for now
if decoder_model is None:
decoder_pretrained_model_name_or_path = kwargs.pop('decoder_pretrained_model_name_or_path', pretrained_model_name_or_path)
decoder_kwargs = {}
            for key in list(kwargs.keys()):
if key.startswith('decoder_'):
decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
if encoder_model:
encoder = encoder_model
else:
# Load and initialize the encoder
kwargs['is_decoder'] = False # Make sure the encoder will be an encoder
if 'distilbert' in pretrained_model_name_or_path:
encoder = DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
encoder = RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
encoder = BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'openai-gpt' in pretrained_model_name_or_path:
encoder = OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'gpt2' in pretrained_model_name_or_path:
encoder = GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'transfo-xl' in pretrained_model_name_or_path:
encoder = TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlnet' in pretrained_model_name_or_path:
encoder = XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
encoder = XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
else:
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
# Load and initialize the decoder
if decoder_model:
decoder = decoder_model
else:
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
            kwargs['is_decoder'] = True  # Make sure the decoder will be a decoder
if 'distilbert' in decoder_pretrained_model_name_or_path:
decoder = DistilBertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
elif 'roberta' in decoder_pretrained_model_name_or_path:
decoder = RobertaModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
elif 'bert' in decoder_pretrained_model_name_or_path:
decoder = BertModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
elif 'openai-gpt' in decoder_pretrained_model_name_or_path:
decoder = OpenAIGPTModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
elif 'gpt2' in decoder_pretrained_model_name_or_path:
decoder = GPT2Model.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
elif 'transfo-xl' in decoder_pretrained_model_name_or_path:
decoder = TransfoXLModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
elif 'xlnet' in decoder_pretrained_model_name_or_path:
decoder = XLNetModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
elif 'xlm' in decoder_pretrained_model_name_or_path:
decoder = XLMModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
else:
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(decoder_pretrained_model_name_or_path))
model = cls(encoder, decoder)
return model
    def forward(self, *inputs, **kwargs):
# Extract decoder inputs
decoder_kwargs = {}
        for key in list(kwargs.keys()):
if key.startswith('decoder_'):
decoder_kwargs[key.replace('decoder_', '')] = kwargs.pop(key)
# Compute encoder hidden states if needed
encoder_hidden_states = kwargs.pop('encoder_hidden_states', None)
if encoder_hidden_states is None:
            encoder_outputs = self.encoder(*inputs, **kwargs)
encoder_hidden_states = encoder_outputs[0]
# Decode
        decoder_kwargs['encoder_hidden_state'] = encoder_hidden_states
decoder_outputs = self.decoder(**decoder_kwargs)
return decoder_outputs
class Model2Model(PreTrainedSeq2seq):
    def tie_weights(self):
# We should tie encoder and decoder embeddings if possible here
pass
class Model2LSTM(PreTrainedSeq2seq):
@classmethod
def from_pretrained(cls, *args, **kwargs):
        if kwargs.get('decoder_model', None) is None:
            # We will create a randomly initialized LSTM model as decoder
            if 'decoder_config' not in kwargs:
                raise ValueError("To load an LSTM in a Seq2seq model, please supply either: "
                                 " - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or "
                                 " - a dictionary of configuration parameters that will be used to initialize a "
                                 "torch.nn.LSTM model as `decoder_config` keyword argument. "
                                 "E.g. `decoder_config={'input_size': 768, 'hidden_size': 768, 'num_layers': 2}`")
            kwargs['decoder_model'] = torch.nn.LSTM(**kwargs.pop('decoder_config'))
model = super(Model2LSTM, cls).from_pretrained(*args, **kwargs)
return model
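A rough end-to-end sketch of how the classes above are meant to be driven, assuming the surrounding [WIP] plumbing works as written; the import paths, checkpoint name and sentence are placeholders:

    import torch
    from transformers import BertTokenizer                  # package name assumed
    from transformers.modeling_seq2seq import Model2Model   # module path assumed for this file

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = Model2Model.from_pretrained('bert-base-uncased')   # BERT encoder + BERT decoder

    source_ids = torch.tensor([tokenizer.encode("Hello, how are you?")])
    # With the current [WIP] masking the decoder input is kept the same length as the source.
    target_ids = source_ids.clone()

    # Keyword arguments prefixed with `decoder_` are stripped of the prefix and
    # forwarded to the decoder, together with the encoder's hidden states.
    decoder_outputs = model(source_ids, decoder_input_ids=target_ids)
    sequence_output = decoder_outputs[0]                     # (batch, target_len, hidden_size)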