Unverified Commit 2431fea9 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Merge pull request #1383 from keskarnitish/master

Adding CTRL
parents e84470ef d9e60f4f
......@@ -131,4 +131,7 @@ examples/runs
# data
/data
serialization_dir
\ No newline at end of file
serialization_dir
# emacs
*.*~
\ No newline at end of file
......@@ -22,7 +22,7 @@
<p>State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch
</h3>
🤗 Transformers (formerly known as `pytorch-transformers` and `pytorch-pretrained-bert`) provides state-of-the-art general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch.
🤗 Transformers (formerly known as `pytorch-transformers` and `pytorch-pretrained-bert`) provides state-of-the-art general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet, CTRL...) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch.
### Features
......@@ -121,6 +121,7 @@ At some point in the future, you'll be able to seamlessly move from pre-training
6. **[XLM](https://github.com/facebookresearch/XLM/)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
8. **[DistilBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/master/examples/distillation).
9. **[CTRL](https://github.com/salesforce/ctrl/)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Peason R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).
......@@ -147,6 +148,7 @@ from transformers import *
MODELS = [(BertModel, BertTokenizer, 'bert-base-uncased'),
(OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'),
(GPT2Model, GPT2Tokenizer, 'gpt2'),
(CTRLModel, CTRLTokenizer, 'ctrl'),
(TransfoXLModel, TransfoXLTokenizer, 'transfo-xl-wt103'),
(XLNetModel, XLNetTokenizer, 'xlnet-base-cased'),
(XLMModel, XLMTokenizer, 'xlm-mlm-enfr-1024'),
......@@ -252,7 +254,7 @@ The library comprises several example scripts with SOTA performances for NLU and
- `run_glue.py`: an example fine-tuning Bert, XLNet and XLM on nine different GLUE tasks (*sequence-level classification*)
- `run_squad.py`: an example fine-tuning Bert, XLNet and XLM on the question answering dataset SQuAD 2.0 (*token-level classification*)
- `run_generation.py`: an example using GPT, GPT-2, Transformer-XL and XLNet for conditional language generation
- `run_generation.py`: an example using GPT, GPT-2, CTRL, Transformer-XL and XLNet for conditional language generation
- other model-specific examples (see the documentation).
Here are three quick usage examples for these scripts:
......@@ -390,7 +392,7 @@ python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ../models/wwm_uncase
This is the model provided as `bert-large-uncased-whole-word-masking-finetuned-squad`.
### `run_generation.py`: Text generation with GPT, GPT-2, Transformer-XL and XLNet
### `run_generation.py`: Text generation with GPT, GPT-2, CTRL, Transformer-XL and XLNet
A conditional generation script is also included to generate text from a prompt.
The generation script includes the [tricks](https://github.com/rusiaaman/XLNet-gen#methodology) proposed by Aman Rusia to get high-quality generation with memory models like Transformer-XL and XLNet (include a predefined text to make short inputs longer).
......@@ -404,6 +406,16 @@ python ./examples/run_generation.py \
--model_name_or_path=gpt2 \
```
and from the Salesforce CTRL model:
```shell
python ./examples/run_generation.py \
--model_type=ctrl \
--length=20 \
--model_name_or_path=gpt2 \
--temperature=0 \
--repetition_penalty=1.2 \
```
## Migrating from pytorch-transformers to transformers
Here is a quick summary of what you should take care of when migrating from `pytorch-transformers` to `transformers`.
......
......@@ -129,4 +129,8 @@ Here is the full list of the currently provided pretrained models together with
| | | | The DistilGPT2 model distilled from the GPT2 model `gpt2` checkpoint. |
| | | (see `details <https://github.com/huggingface/transformers/tree/master/examples/distillation>`__) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| CTRL | ``ctrl`` | | 48-layer, 1280-hidden, 16-heads, 1.6B parameters |
| | | | Salesforce's Large-sized CTRL English model |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
.. <https://huggingface.co/transformers/examples.html>`__
\ No newline at end of file
......@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
"""
from __future__ import absolute_import, division, print_function, unicode_literals
......@@ -26,12 +26,13 @@ import torch
import torch.nn.functional as F
import numpy as np
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
from transformers import CTRLLMHeadModel, CTRLTokenizer
from transformers import XLMWithLMHeadModel, XLMTokenizer
......@@ -42,10 +43,11 @@ logger = logging.getLogger(__name__)
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig)), ())
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
MODEL_CLASSES = {
'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
......@@ -105,8 +107,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
return logits
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False,
xlm_lang=None, device='cpu'):
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, is_xlnet=False, xlm_lang=None, device='cpu'):
context = torch.tensor(context, dtype=torch.long, device=device)
context = context.unsqueeze(0).repeat(num_samples, 1)
generated = context
......@@ -128,9 +129,17 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
next_token_logits = outputs[0][0, -1, :] / temperature
next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)
# reptition penalty from CTRL (https://arxiv.org/abs/1909.05858)
for _ in set(generated):
next_token_logits[_] /= repetition_penalty
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
if temperature == 0: #greedy sampling:
next_token = torch.argmax(filtered_logits).unsqueeze(0)
else:
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
return generated
......@@ -145,7 +154,10 @@ def main():
parser.add_argument("--padding_text", type=str, default="")
parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
parser.add_argument("--length", type=int, default=20)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0,
help="temperature of 0 implies greedy sampling")
parser.add_argument("--repetition_penalty", type=float, default=1.0,
help="primarily useful for CTRL model; in that case, use 1.2")
parser.add_argument("--top_k", type=int, default=0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--no_cuda", action='store_true',
......@@ -155,7 +167,10 @@ def main():
parser.add_argument('--stop_token', type=str, default=None,
help="Token at which text generation is stopped")
args = parser.parse_args()
if args.model_type in ["ctrl"]:
if args.temperature > 0.7 :
print('CTRL typically works better with lower temperatures (and lower top_k).')
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count()
......@@ -201,6 +216,7 @@ def main():
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty,
is_xlnet=bool(args.model_type == "xlnet"),
xlm_lang=xlm_lang,
device=args.device,
......
......@@ -37,6 +37,7 @@ from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from .tokenization_xlm import XLMTokenizer
from .tokenization_roberta import RobertaTokenizer
......@@ -49,7 +50,9 @@ from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_transfo_xl import TransfoXLConfig, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlnet import XLNetConfig, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlm import XLMConfig, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
......@@ -73,6 +76,9 @@ if is_torch_available():
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel,
load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_ctrl import (CTRLPreTrainedModel, CTRLModel,
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForMultipleChoice,
XLNetForQuestionAnsweringSimple, XLNetForQuestionAnswering,
......@@ -149,6 +155,11 @@ if is_tf_available():
load_distilbert_pt_weights_in_tf2,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel,
TFCTRLLMHeadModel,
load_ctrl_pt_weights_in_tf2,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
# TF 2.0 <=> PyTorch conversion utilities
if is_tf_available() and is_torch_available():
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
......
......@@ -26,6 +26,7 @@ from .configuration_xlnet import XLNetConfig
from .configuration_xlm import XLMConfig
from .configuration_roberta import RobertaConfig
from .configuration_distilbert import DistilBertConfig
from .configuration_ctrl import CTRLConfig
logger = logging.getLogger(__name__)
......@@ -49,7 +50,7 @@ class AutoConfig(object):
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlm`: XLMConfig (XLM model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `ctrl` : CTRLConfig (CTRL model)
This class cannot be instantiated using `__init__()` (throw an error).
"""
def __init__(self):
......@@ -71,7 +72,7 @@ class AutoConfig(object):
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlm`: XLMConfig (XLM model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `ctrl` : CTRLConfig (CTRL model)
Params:
pretrained_model_name_or_path: either:
......@@ -129,7 +130,8 @@ class AutoConfig(object):
return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'ctrl' in pretrained_model_name_or_path:
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
"'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path))
# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Salesforce CTRL configuration """
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import sys
from io import open
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}
class CTRLConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `CTRLModel`.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file.
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
dff: Size of the inner dimension of the FFN.
n_embd: Dimensionality of the embeddings and hidden states.
n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in
the Transformer encoder.
layer_norm_epsilon: epsilon to use in the layer norm layers
resid_pdrop: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler.
attn_pdrop: The dropout ratio for the attention
probabilities.
embd_pdrop: The dropout ratio for the embeddings.
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(
self,
vocab_size_or_config_json_file=246534,
n_positions=256,
n_ctx=256,
n_embd=1280,
dff=8192,
n_layer=48,
n_head=16,
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
num_labels=1,
summary_type='cls_index',
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
**kwargs
):
"""Constructs CTRLConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file.
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
dff: Size of the inner dimension of the FFN.
n_embd: Dimensionality of the embeddings and hidden states.
n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in
the Transformer encoder.
layer_norm_epsilon: epsilon to use in the layer norm layers
resid_pdrop: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler.
attn_pdrop: The dropout ratio for the attention
probabilities.
embd_pdrop: The dropout ratio for the embeddings.
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
super(CTRLConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.dff = dff
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.num_labels = num_labels
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif not isinstance(vocab_size_or_config_json_file, int):
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
@property
def max_position_embeddings(self):
return self.n_positions
@property
def hidden_size(self):
return self.n_embd
@property
def num_attention_heads(self):
return self.n_head
@property
def num_hidden_layers(self):
return self.n_layer
......@@ -31,7 +31,8 @@ from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAns
TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
if is_torch_available():
import torch
......@@ -43,7 +44,8 @@ if is_torch_available():
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
......@@ -52,7 +54,8 @@ else:
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = (
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) = (
None, None, None, None,
None, None,
None, None,
......@@ -60,7 +63,8 @@ else:
None, None,
None, None,
None, None, None,
None, None, None,)
None, None, None,
None, None)
import logging
......@@ -80,6 +84,7 @@ MODEL_CLASSES = {
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
}
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
......
......@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try:
import tensorflow as tf
assert int(tf.__version__[0]) >= 2
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__))
except (ImportError, AssertionError):
......
......@@ -21,6 +21,7 @@ import logging
from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel
from .modeling_ctrl import CTRLModel, CTRLLMHeadModel
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
......@@ -51,6 +52,7 @@ class AutoModel(object):
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
......@@ -73,6 +75,7 @@ class AutoModel(object):
- contains `bert`: BertModel (Bert model)
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
- contains `xlnet`: XLNetModel (XLNet model)
- contains `xlm`: XLMModel (XLM model)
......@@ -149,10 +152,11 @@ class AutoModel(object):
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'ctrl' in pretrained_model_name_or_path:
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
"'xlm', 'roberta, 'ctrl'".format(pretrained_model_name_or_path))
class AutoModelWithLMHead(object):
......@@ -172,6 +176,7 @@ class AutoModelWithLMHead(object):
- contains `bert`: BertForMaskedLM (Bert model)
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
- contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model)
- contains `ctrl`: CTRLLMModel (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
- contains `xlm`: XLMWithLMHeadModel (XLM model)
......@@ -273,10 +278,11 @@ class AutoModelWithLMHead(object):
return XLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
return XLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'ctrl' in pretrained_model_name_or_path:
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
"'xlm', 'roberta','ctrl'".format(pretrained_model_name_or_path))
class AutoModelForSequenceClassification(object):
......
# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch CTRL model."""
from __future__ import absolute_import, division, print_function, unicode_literals
import collections
import json
import logging
import math
import os
import sys
from io import open
import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
from .configuration_ctrl import CTRLConfig
from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/seqlen256_v1.bin"}
def angle_defn(pos, i, d_model_size):
angle_rates = 1 / torch.pow(10000, (2 * (i//2)) / d_model_size)
return pos * angle_rates
def positional_encoding(position, d_model_size, dtype):
# create the sinusoidal pattern for the positional encoding
angle_rads = (angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1),
torch.arange(d_model_size, dtype=dtype).unsqueeze(0),
d_model_size))
sines = torch.sin(angle_rads[:, 0::2])
cosines = torch.cos(angle_rads[:, 1::2])
pos_encoding = torch.cat([sines, cosines], dim=-1)
return pos_encoding
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
# calculate attention
matmul_qk = torch.matmul(q, k.permute(0,1,3,2))
dk = k.shape[-1]
scaled_attention_logits = matmul_qk / np.sqrt(dk)
if mask is not None:
scaled_attention_logits += (mask * -1e4)
if attention_mask is not None:
# Apply the attention mask
scaled_attention_logits = scaled_attention_logits + attention_mask
attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
# Mask heads if we want to
if head_mask is not None:
attention_weights = attention_weights * head_mask
output = torch.matmul(attention_weights, v)
return output, attention_weights
class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model_size, num_heads, output_attentions=False):
super(MultiHeadAttention, self).__init__()
self.output_attentions = output_attentions
self.num_heads = num_heads
self.d_model_size = d_model_size
self.depth = int(d_model_size / self.num_heads)
self.Wq = torch.nn.Linear(d_model_size, d_model_size)
self.Wk = torch.nn.Linear(d_model_size, d_model_size)
self.Wv = torch.nn.Linear(d_model_size, d_model_size)
self.dense = torch.nn.Linear(d_model_size, d_model_size)
def split_into_heads(self, x, batch_size):
x = x.reshape(batch_size, -1, self.num_heads, self.depth)
return x.permute([0, 2, 1, 3])
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None):
batch_size = q.shape[0]
q = self.Wq(q)
k = self.Wk(k)
v = self.Wv(v)
q = self.split_into_heads(q, batch_size)
k = self.split_into_heads(k, batch_size)
v = self.split_into_heads(v, batch_size)
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)
present = torch.stack((k, v))
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
scaled_attention = output[0].permute([0, 2, 1, 3])
attn = output[1]
original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
output = self.dense(original_size_attention)
outputs = (output, present)
if self.output_attentions:
outputs = outputs + (attn,)
return outputs
def point_wise_feed_forward_network(d_model_size, dff):
return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff),
torch.nn.ReLU(),
torch.nn.Linear(dff, d_model_size))
class EncoderLayer(torch.nn.Module):
def __init__(self, d_model_size, num_heads, dff, rate=0.1, output_attentions=False):
super(EncoderLayer, self).__init__()
self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, output_attentions)
self.ffn = point_wise_feed_forward_network(d_model_size, dff)
self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
self.layernorm2 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
self.dropout1 = torch.nn.Dropout(rate)
self.dropout2 = torch.nn.Dropout(rate)
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention(normed, normed, normed, mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask)
attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output)
out1 = x + attn_output
out2 = self.layernorm2(out1)
ffn_output = self.ffn(out2)
ffn_output = self.dropout2(ffn_output)
out2 = out1 + ffn_output
outputs = (out2,) + attn_outputs[1:]
return outputs
class CTRLPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class = CTRLConfig
pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
def _init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
CTRL_START_DOCSTRING = r""" CTRL model was proposed in
`CTRL: A Conditional Transformer Language Model for Controllable Generation`_
by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
corpus of ~140 GB of text data with the first token reserved as a control code (such as Links, Books, Wikipedia etc.).
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`CTRL: A Conditional Transformer Language Model for Controllable Generation`:
https://www.github.com/salesforce/ctrl
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~transformers.CTRLConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
CTRL_INPUTS_DOCSTRING = r""" Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
CTRL is a model with absolute position embeddings so it's usually advised to pad the inputs on
the right rather than the left.
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**past**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
The embeddings from these tokens will be summed with the respective token embeddings.
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
CTRL_START_DOCSTRING, CTRL_INPUTS_DOCSTRING)
class CTRLModel(CTRLPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
model = CTRLModel.from_pretrained('ctrl')
input_ids = torch.tensor(tokenizer.encode("Links Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(CTRLModel, self).__init__(config)
self.output_hidden_states = config.output_hidden_states
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
self.output_attentions = config.output_attentions
self.w = nn.Embedding(config.vocab_size, config.n_embd)
self.dropout = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([EncoderLayer(config.n_embd,
config.n_head,
config.dff,
config.resid_pdrop,
config.output_attentions) for _ in range(config.n_layer)])
self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
self.w = self._get_resized_embeddings(self.w, new_num_tokens)
return self.w
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
if past is None:
past_length = 0
past = [None] * len(self.h)
else:
past_length = past[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# Attention mask.
if attention_mask is not None:
attention_mask = attention_mask.view(-1, input_shape[-1])
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
else:
head_mask = [None] * self.config.n_layer
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
token_type_embeds = self.w(token_type_ids)
token_type_embeds *= np.sqrt(self.d_model_size)
else:
token_type_embeds = 0
position_ids = position_ids.view(-1, input_shape[-1])
inputs_embeds = self.w(input_ids)
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len = input_ids.shape[-1]
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)
inputs_embeds *= np.sqrt(self.d_model_size)
pos_embeds = self.pos_encoding[position_ids, :].to(inputs_embeds.device)
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
hidden_states = self.dropout(hidden_states)
output_shape = input_shape + (inputs_embeds.size(-1),)
presents = ()
all_hidden_states = ()
all_attentions = []
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = h(hidden_states,
mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i])
hidden_states, present = outputs[:2]
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(outputs[2])
hidden_states = self.layernorm(hidden_states)
hidden_states = hidden_states.view(*output_shape)
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,)
return outputs
@add_start_docstrings("""The CTRL Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, CTRL_START_DOCSTRING, CTRL_INPUTS_DOCSTRING)
class CTRLLMHeadModel(CTRLPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-1`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import torch
from transformers import CTRLTokenizer, CTRLLMHeadModel
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
model = CTRLLMHeadModel.from_pretrained('ctrl')
input_ids = torch.tensor(tokenizer.encode("Links Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=input_ids)
loss, logits = outputs[:2]
"""
def __init__(self, config):
super(CTRLLMHeadModel, self).__init__(config)
self.transformer = CTRLModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.lm_head, self.transformer.w)
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
labels=None):
transformer_outputs = self.transformer(input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
outputs = (lm_logits,) + transformer_outputs[1:]
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
......@@ -170,7 +170,7 @@ class Attention(nn.Module):
# w = w * self.bias + -1e9 * (1 - self.bias) # TF implem method: mask_attn_weights
# XD: self.b may be larger than w, so we need to crop it
b = self.bias[:, :, : w.size(-2), : w.size(-1)]
w = w * b + -1e9 * (1 - b)
w = w * b + - 1e4 * (1 - b)
if attention_mask is not None:
# Apply the attention mask
......
......@@ -172,7 +172,8 @@ class RobertaModel(BertModel):
if input_ids[:, 0].sum().item() != 0:
logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
"This model requires special tokens in order to work. "
"Please specify add_special_tokens=True in your encoding.")
"Please specify add_special_tokens=True in your tokenize.encode()"
"or tokenizer.convert_tokens_to_ids().")
return super(RobertaModel, self).forward(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
......
# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 CTRL model."""
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
import sys
from io import open
import numpy as np
import tensorflow as tf
from .configuration_ctrl import CTRLConfig
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, TFSharedEmbeddings
from .file_utils import add_start_docstrings
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
logger = logging.getLogger(__name__)
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"}
def load_ctrl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
# build the network
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
tf_inputs = tf.constant(inputs_list)
tfo = tf_model(tf_inputs, training=False)
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
def angle_defn(pos, i, d_model_size):
angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model_size))
return pos * angle_rates
def positional_encoding(position, d_model_size):
# create the sinusoidal pattern for the positional encoding
angle_rads = angle_defn(np.arange(position)[:, np.newaxis],
np.arange(d_model_size)[np.newaxis, :],
d_model_size)
sines = np.sin(angle_rads[:, 0::2])
cosines = np.cos(angle_rads[:, 1::2])
# pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1)[np.newaxis, ...], dtype=tf.float32)
pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1), dtype=tf.float32)
return pos_encoding
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
# calculate attention
matmul_qk = tf.matmul(q, k, transpose_b=True)
dk = tf.cast(shape_list(k)[-1], tf.float32)
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
if mask is not None:
scaled_attention_logits += (mask * -1e4)
if attention_mask is not None:
# Apply the attention mask
scaled_attention_logits = scaled_attention_logits + attention_mask
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
# Mask heads if we want to
if head_mask is not None:
attention_weights = attention_weights * head_mask
output = tf.matmul(attention_weights, v)
return output, attention_weights
class TFMultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs):
super(TFMultiHeadAttention, self).__init__(**kwargs)
self.output_attentions = output_attentions
self.num_heads = num_heads
self.d_model_size = d_model_size
self.depth = int(d_model_size / self.num_heads)
self.Wq = tf.keras.layers.Dense(d_model_size, name='Wq')
self.Wk = tf.keras.layers.Dense(d_model_size, name='Wk')
self.Wv = tf.keras.layers.Dense(d_model_size, name='Wv')
self.dense = tf.keras.layers.Dense(d_model_size, name='dense')
def split_into_heads(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False):
v, k, q, mask, layer_past, attention_mask, head_mask = inputs
batch_size = q.shape[0]
q = self.Wq(q)
k = self.Wk(k)
v = self.Wv(v)
q = self.split_into_heads(q, batch_size)
k = self.split_into_heads(k, batch_size)
v = self.split_into_heads(v, batch_size)
if layer_past is not None:
past_key, past_value = tf.unstack(layer_past, axis=1)
k = tf.concat((past_key, k), dim=-2)
v = tf.concat((past_value, v), dim=-2)
present = tf.stack((k, v), axis=1)
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
attn = output[1]
original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
output = self.dense(original_size_attention)
outputs = (output, present)
if self.output_attentions:
outputs = outputs + (attn,)
return outputs
def point_wise_feed_forward_network(d_model_size, dff, name=""):
return tf.keras.Sequential([
tf.keras.layers.Dense(dff, activation='relu', name="0"),
tf.keras.layers.Dense(d_model_size, name="2")
], name="ffn")
class TFEncoderLayer(tf.keras.layers.Layer):
def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs):
super(TFEncoderLayer, self).__init__(**kwargs)
self.multi_head_attention = TFMultiHeadAttention(d_model_size,
num_heads,
output_attentions,
name="multi_head_attention")
self.ffn = point_wise_feed_forward_network(d_model_size, dff, name="ffn")
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm2")
self.dropout1 = tf.keras.layers.Dropout(rate)
self.dropout2 = tf.keras.layers.Dropout(rate)
def call(self, inputs, training=False):
x, mask, layer_past, attention_mask, head_mask = inputs
normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention([normed, normed, normed, mask, layer_past,
attention_mask, head_mask], training=training)
attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output, training=training)
out1 = x + attn_output
out2 = self.layernorm2(out1)
ffn_output = self.ffn(out2)
ffn_output = self.dropout2(ffn_output, training=training)
out2 = out1 + ffn_output
outputs = (out2,) + attn_outputs[1:]
return outputs
class TFCTRLMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFCTRLMainLayer, self).__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
self.output_attentions = config.output_attentions
self.w = TFSharedEmbeddings(config.vocab_size,
config.n_embd,
initializer_range=config.initializer_range,
name="w")
self.dropout = tf.keras.layers.Dropout(config.embd_pdrop)
self.h = [TFEncoderLayer(config.n_embd,
config.n_head,
config.dff,
config.resid_pdrop,
config.layer_norm_epsilon,
config.output_attentions,
name='h_._{}'.format(i)) for i in range(config.n_layer)]
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
raise NotImplementedError
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
past = inputs.get('past', past)
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
if past is None:
past_length = 0
past = [None] * len(self.h)
else:
past_length = shape_list(past[0][0])[-2]
if position_ids is None:
position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
position_ids = tf.tile(position_ids, [shape_list(input_ids)[0], 1])
# Attention mask.
if attention_mask is not None:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = tf.cast(attention_mask, tf.float32)
attention_mask = (1.0 - attention_mask) * -10000.0
else:
attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_layers
if token_type_ids is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
token_type_embeds = self.w(token_type_ids, mode='embedding')
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
else:
token_type_embeds = 0
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
inputs_embeds = self.w(input_ids, mode='embedding')
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len = input_shape[-1]
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
pos_embeds = tf.gather(self.pos_encoding, position_ids)
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
hidden_states = self.dropout(hidden_states, training=training)
output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = ()
all_hidden_states = ()
all_attentions = []
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training)
hidden_states, present = outputs[:2]
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(outputs[2])
hidden_states = self.layernorm(hidden_states)
hidden_states = tf.reshape(hidden_states, output_shape)
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,)
return outputs
class TFCTRLPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class = CTRLConfig
pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
load_pt_weights = load_ctrl_pt_weights_in_tf2
CTRL_START_DOCSTRING = r""" CTRL model was proposed in
`CTRL: A Conditional Transformer Language Model for Controllable Generation`_
by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
corpus of ~140 GB of text data with the first token reserved as a control code (such as Links, Books, Wikipedia etc.).
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`CTRL: A Conditional Transformer Language Model for Controllable Generation`:
https://www.github.com/salesforce/ctrl
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~transformers.CTRLConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
CTRL_INPUTS_DOCSTRING = r""" Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
CTRL is a model with absolute position embeddings so it's usually advised to pad the inputs on
the right rather than the left.
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**past**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
The embeddings from these tokens will be summed with the respective token embeddings.
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
CTRL_START_DOCSTRING, CTRL_INPUTS_DOCSTRING)
class TFCTRLModel(TFCTRLPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**past**:
list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import CTRLTokenizer, TFCTRLModel
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
model = TFCTRLModel.from_pretrained('ctrl')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config, *inputs, **kwargs):
super(TFCTRLModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFCTRLMainLayer(config, name='transformer')
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
return outputs
class TFCTRLLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super(TFCTRLLMHead, self).__init__(**kwargs)
self.vocab_size = config.vocab_size
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,),
initializer='zeros',
trainable=True,
name='bias')
super(TFCTRLLMHead, self).build(input_shape)
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
@add_start_docstrings("""The CTRL Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, CTRL_START_DOCSTRING, CTRL_INPUTS_DOCSTRING)
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import torch
from transformers import CTRLTokenizer, TFCTRLLMHeadModel
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
model = TFCTRLLMHeadModel.from_pretrained('ctrl')
input_ids = torch.tensor(tokenizer.encode("Links Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=input_ids)
loss, logits = outputs[:2]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFCTRLLMHeadModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFCTRLMainLayer(config, name='transformer')
self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
outputs = (lm_logits,) + transformer_outputs[1:]
return outputs # lm_logits, presents, (all hidden_states), (attentions)
# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import pytest
import shutil
import pdb
from transformers import is_torch_available
if is_torch_available():
from transformers import (CTRLConfig, CTRLModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel)
else:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
class CTRLModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_head_masking = False
class CTRLModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_token_type_ids=True,
use_input_mask=True,
use_labels=True,
use_mc_token_ids=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_token_type_ids = use_token_type_ids
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.use_mc_token_ids = use_mc_token_ids
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
mc_token_ids = None
if self.use_mc_token_ids:
mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = CTRLConfig(
vocab_size_or_config_json_file=self.vocab_size,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,
# intermediate_size=self.intermediate_size,
# hidden_act=self.hidden_act,
# hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions=self.max_position_embeddings,
n_ctx=self.max_position_embeddings
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_and_check_ctrl_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = CTRLModel(config=config)
model.eval()
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
model(input_ids, token_type_ids=token_type_ids)
sequence_output, presents = model(input_ids)
result = {
"sequence_output": sequence_output,
"presents": presents,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertEqual(len(result["presents"]), config.n_layer)
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = CTRLLMHeadModel(config)
model.eval()
loss, lm_logits, _ = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
result = {
"loss": loss,
"lm_logits": lm_logits
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["lm_logits"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, head_mask, token_type_ids,
mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {
'input_ids': input_ids,
'token_type_ids': token_type_ids,
'head_mask': head_mask
}
return config, inputs_dict
def setUp(self):
self.model_tester = CTRLModelTest.CTRLModelTester(self)
self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_ctrl_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_ctrl_model(*config_and_inputs)
def test_ctrl_lm_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = CTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
......@@ -71,6 +71,8 @@ class TFCommonTestCases:
if not is_torch_available():
return
import torch
import numpy as np
import transformers
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -79,12 +81,23 @@ class TFCommonTestCases:
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beggining
pt_model_class = getattr(transformers, pt_model_class_name)
config.output_hidden_states = True
tf_model = model_class(config)
pt_model = pt_model_class(config)
# Check we can load pt model in tf and vice-versa (architecture similar)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
pt_inputs_dict = dict((name, torch.from_numpy(key.numpy()).to(torch.long))
for name, key in inputs_dict.items())
with torch.no_grad():
pto = pt_model(**pt_inputs_dict)
tfo = tf_model(inputs_dict)
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2)
def test_keyword_and_dict_args(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pytest
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from transformers import CTRLConfig, is_tf_available
if is_tf_available():
import tensorflow as tf
from transformers.modeling_tf_ctrl import (TFCTRLModel, TFCTRLLMHeadModel,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()
class TFCTRLModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_token_type_ids=True,
use_input_mask=True,
use_labels=True,
use_mc_token_ids=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_token_type_ids = use_token_type_ids
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.use_mc_token_ids = use_mc_token_ids
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
mc_token_ids = None
if self.use_mc_token_ids:
mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = CTRLConfig(
vocab_size_or_config_json_file=self.vocab_size,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,
# intermediate_size=self.intermediate_size,
# hidden_act=self.hidden_act,
# hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions=self.max_position_embeddings,
n_ctx=self.max_position_embeddings
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
def create_and_check_ctrl_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = TFCTRLModel(config=config)
inputs = {'input_ids': input_ids,
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
sequence_output = model(inputs)[0]
inputs = [input_ids, None, input_mask] # None is the input for 'past'
sequence_output = model(inputs)[0]
sequence_output = model(input_ids)[0]
result = {
"sequence_output": sequence_output.numpy(),
}
self.parent.assertListEqual(
list(result["sequence_output"].shape),
[self.batch_size, self.seq_length, self.hidden_size])
def create_and_check_ctrl_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = TFCTRLLMHeadModel(config=config)
inputs = {'input_ids': input_ids,
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
prediction_scores = model(inputs)[0]
result = {
"prediction_scores": prediction_scores.numpy(),
}
self.parent.assertListEqual(
list(result["prediction_scores"].shape),
[self.batch_size, self.seq_length, self.vocab_size])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, head_mask, token_type_ids,
mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
return config, inputs_dict
def setUp(self):
self.model_tester = TFCTRLModelTest.TFCTRLModelTester(self)
self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_ctrl_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_ctrl_model(*config_and_inputs)
def test_ctrl_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_ctrl_lm_head(*config_and_inputs)
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
......@@ -222,7 +222,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_gpt2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFGPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
......
# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
import json
from io import open
from transformers.tokenization_ctrl import CTRLTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases
class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = CTRLTokenizer
def setUp(self):
super(CTRLTokenizationTest, self).setUp()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab = ['adapt', 're@@', 'a@@', 'apt', 'c@@', 't', '<unk>']
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", 'a p', 'ap t</w>', 'r e', 'a d', 'ad apt</w>', '']
self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = u"adapt react readapt apt"
output_text = u"adapt react readapt apt"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = CTRLTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "adapt react readapt apt"
bpe_tokens = 'adapt re@@ a@@ c@@ t re@@ adapt apt'.split()
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [0, 1, 2, 4, 5, 1, 0, 3, 6]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__':
unittest.main()
......@@ -21,6 +21,7 @@ import logging
from .tokenization_bert import BertTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_xlnet import XLNetTokenizer
from .tokenization_xlm import XLMTokenizer
......@@ -45,6 +46,7 @@ class AutoTokenizer(object):
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
......@@ -67,6 +69,7 @@ class AutoTokenizer(object):
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
......@@ -114,7 +117,8 @@ class AutoTokenizer(object):
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'ctrl' in pretrained_model_name_or_path:
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
"'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment